mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
refactor: move src and tests from lib/crewai to root
This commit is contained in:
0
tests/rag/__init__.py
Normal file
0
tests/rag/__init__.py
Normal file
0
tests/rag/chromadb/__init__.py
Normal file
0
tests/rag/chromadb/__init__.py
Normal file
774
tests/rag/chromadb/test_client.py
Normal file
774
tests/rag/chromadb/test_client.py
Normal file
@@ -0,0 +1,774 @@
|
||||
"""Tests for ChromaDBClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chromadb_client() -> Mock:
    """Create a mock ChromaDB client.

    Returns:
        A ``Mock`` built with ``spec=ClientAPI``, so accessing an attribute
        that the synchronous ChromaDB client API does not define raises
        ``AttributeError``.
    """
    # Imported inside the fixture, as in the original; presumably so that
    # chromadb is only loaded when a test actually requests this fixture.
    from chromadb.api import ClientAPI

    return Mock(spec=ClientAPI)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_async_chromadb_client() -> Mock:
    """Create a mock async ChromaDB client.

    Returns:
        A ``Mock`` built with ``spec=AsyncClientAPI``. Individual tests
        replace the methods they exercise with ``AsyncMock`` instances so
        that the calls can be awaited.
    """
    # Imported inside the fixture, as in the original; presumably so that
    # chromadb is only loaded when a test actually requests this fixture.
    from chromadb.api import AsyncClientAPI

    return Mock(spec=AsyncClientAPI)
|
||||
|
||||
|
||||
@pytest.fixture
def client(mock_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient instance for testing.

    Args:
        mock_chromadb_client: The mocked synchronous ChromaDB client that the
            wrapper under test delegates to.

    Returns:
        A ``ChromaDBClient`` wired to the mock client and a throwaway ``Mock``
        embedding function, using the default batch size.
    """
    return ChromaDBClient(
        client=mock_chromadb_client, embedding_function=Mock()
    )
|
||||
|
||||
|
||||
@pytest.fixture
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient instance with custom batch size for testing.

    Args:
        mock_chromadb_client: The mocked synchronous ChromaDB client that the
            wrapper under test delegates to.

    Returns:
        A ``ChromaDBClient`` whose ``default_batch_size`` is 2, small enough
        that a handful of documents is split across several upsert calls.
    """
    return ChromaDBClient(
        client=mock_chromadb_client,
        embedding_function=Mock(),
        default_batch_size=2,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient with an async client and custom batch size.

    Args:
        mock_async_chromadb_client: The mocked asynchronous ChromaDB client
            that the wrapper under test delegates to.

    Returns:
        A ``ChromaDBClient`` whose ``default_batch_size`` is 2, small enough
        that a handful of documents is split across several upsert calls.
    """
    return ChromaDBClient(
        client=mock_async_chromadb_client,
        embedding_function=Mock(),
        default_batch_size=2,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient instance with async client for testing.

    Args:
        mock_async_chromadb_client: The mocked asynchronous ChromaDB client
            that the wrapper under test delegates to.

    Returns:
        A ``ChromaDBClient`` wired to the mock async client and a throwaway
        ``Mock`` embedding function, using the default batch size.
    """
    return ChromaDBClient(
        client=mock_async_chromadb_client, embedding_function=Mock()
    )
|
||||
|
||||
|
||||
class TestChromaDBClient:
    """Test suite for ChromaDBClient.

    Each test drives the wrapper through a mocked ChromaDB client (supplied
    by the module-level fixtures) and then inspects the calls recorded on the
    mock. Synchronous tests use the ``client`` fixtures; asynchronous tests
    use the ``async_client`` fixtures and replace the exercised methods with
    ``AsyncMock`` so they can be awaited.
    """

    # ------------------------------------------------------------------
    # create_collection / acreate_collection
    # ------------------------------------------------------------------

    def test_create_collection(self, client, mock_chromadb_client) -> None:
        """Test that create_collection calls the underlying client correctly."""
        client.create_collection(collection_name="test_collection")

        # With no overrides, these are the defaults the wrapper must forward.
        mock_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            configuration=None,
            metadata={"hnsw:space": "cosine"},
            embedding_function=client.embedding_function,
            data_loader=None,
            get_or_create=False,
        )

    def test_create_collection_with_all_params(
        self, client, mock_chromadb_client
    ) -> None:
        """Test create_collection with all optional parameters."""
        mock_config = Mock()
        mock_metadata = {"key": "value"}
        mock_embedding_func = Mock()
        mock_data_loader = Mock()

        client.create_collection(
            collection_name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
            get_or_create=True,
        )

        # Every explicit argument must reach the underlying client unchanged.
        mock_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
            get_or_create=True,
        )

    @pytest.mark.asyncio
    async def test_acreate_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test that acreate_collection calls the underlying client correctly."""
        # Make the mock's create_collection an AsyncMock
        mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)

        await async_client.acreate_collection(collection_name="test_collection")

        mock_async_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            configuration=None,
            metadata={"hnsw:space": "cosine"},
            embedding_function=async_client.embedding_function,
            data_loader=None,
            get_or_create=False,
        )

    @pytest.mark.asyncio
    async def test_acreate_collection_with_all_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test acreate_collection with all optional parameters."""
        # Make the mock's create_collection an AsyncMock
        mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)

        mock_config = Mock()
        mock_metadata = {"key": "value"}
        mock_embedding_func = Mock()
        mock_data_loader = Mock()

        await async_client.acreate_collection(
            collection_name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
            get_or_create=True,
        )

        mock_async_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
            get_or_create=True,
        )

    # ------------------------------------------------------------------
    # get_or_create_collection / aget_or_create_collection
    # ------------------------------------------------------------------

    def test_get_or_create_collection(self, client, mock_chromadb_client) -> None:
        """Test that get_or_create_collection calls the underlying client correctly."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        result = client.get_or_create_collection(collection_name="test_collection")

        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            configuration=None,
            metadata={"hnsw:space": "cosine"},
            embedding_function=client.embedding_function,
            data_loader=None,
        )
        # The wrapper must hand back whatever the underlying client returned.
        assert result == mock_collection

    def test_get_or_create_collection_with_all_params(
        self, client, mock_chromadb_client
    ) -> None:
        """Test get_or_create_collection with all optional parameters."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
        mock_config = Mock()
        mock_metadata = {"key": "value"}
        mock_embedding_func = Mock()
        mock_data_loader = Mock()

        result = client.get_or_create_collection(
            collection_name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
        )

        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
        )
        assert result == mock_collection

    @pytest.mark.asyncio
    async def test_aget_or_create_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test that aget_or_create_collection calls the underlying client correctly."""
        mock_collection = Mock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection"
        )

        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            configuration=None,
            metadata={"hnsw:space": "cosine"},
            embedding_function=async_client.embedding_function,
            data_loader=None,
        )
        assert result == mock_collection

    @pytest.mark.asyncio
    async def test_aget_or_create_collection_with_all_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aget_or_create_collection with all optional parameters."""
        mock_collection = Mock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )
        mock_config = Mock()
        mock_metadata = {"key": "value"}
        mock_embedding_func = Mock()
        mock_data_loader = Mock()

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
        )

        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            configuration=mock_config,
            metadata=mock_metadata,
            embedding_function=mock_embedding_func,
            data_loader=mock_data_loader,
        )
        assert result == mock_collection

    # ------------------------------------------------------------------
    # add_documents / aadd_documents
    # ------------------------------------------------------------------

    def test_add_documents(self, client, mock_chromadb_client) -> None:
        """Test that add_documents adds documents to collection."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {
                "content": "Test document",
                "metadata": {"source": "test"},
            }
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )

        # Verify documents were added to collection. No doc_id was supplied,
        # so only the number of ids is checked, not their values.
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert len(call_args.kwargs["ids"]) == 1
        assert call_args.kwargs["documents"] == ["Test document"]
        assert call_args.kwargs["metadatas"] == [{"source": "test"}]

    def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
        """Test add_documents with custom document IDs."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {
                "doc_id": "custom_id_1",
                "content": "First document",
                "metadata": {"source": "test1"},
            },
            {
                "doc_id": "custom_id_2",
                "content": "Second document",
                "metadata": {"source": "test2"},
            },
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        # Caller-supplied doc_id values must be used verbatim as ids.
        mock_collection.upsert.assert_called_once_with(
            ids=["custom_id_1", "custom_id_2"],
            documents=["First document", "Second document"],
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
        """Test add_documents with documents that have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args.kwargs["metadatas"] == [{}, {}, {"key": "value"}]

    def test_add_documents_all_without_metadata(
        self, client, mock_chromadb_client
    ) -> None:
        """Test add_documents when all documents have no metadata."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"content": "Document 1"},
            {"content": "Document 2"},
            {"content": "Document 3"},
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        # When no document carries metadata, metadatas is expected to be None
        # rather than a list of empty dicts.
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args.kwargs["metadatas"] is None

    def test_add_documents_empty_list_raises_error(
        self, client, mock_chromadb_client
    ) -> None:
        """Test that add_documents raises error for empty documents list."""
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            client.add_documents(collection_name="test_collection", documents=[])

    @pytest.mark.asyncio
    async def test_aadd_documents(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test that aadd_documents adds documents to collection asynchronously."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {
                "content": "Test document",
                "metadata": {"source": "test"},
            }
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )

        # Verify documents were added to collection
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert len(call_args.kwargs["ids"]) == 1
        assert call_args.kwargs["documents"] == ["Test document"]
        assert call_args.kwargs["metadatas"] == [{"source": "test"}]

    @pytest.mark.asyncio
    async def test_aadd_documents_with_custom_ids(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with custom document IDs."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {
                "doc_id": "custom_id_1",
                "content": "First document",
                "metadata": {"source": "test1"},
            },
            {
                "doc_id": "custom_id_2",
                "content": "Second document",
                "metadata": {"source": "test2"},
            },
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        mock_collection.upsert.assert_called_once_with(
            ids=["custom_id_1", "custom_id_2"],
            documents=["First document", "Second document"],
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    @pytest.mark.asyncio
    async def test_aadd_documents_without_metadata(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with documents that have no metadata."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {"content": "Document without metadata"},
            {"content": "Another document", "metadata": None},
            {"content": "Document with metadata", "metadata": {"key": "value"}},
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        # Verify upsert was called with empty dicts for missing metadata
        mock_collection.upsert.assert_called_once()
        call_args = mock_collection.upsert.call_args
        assert call_args.kwargs["metadatas"] == [{}, {}, {"key": "value"}]

    @pytest.mark.asyncio
    async def test_aadd_documents_empty_list_raises_error(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test that aadd_documents raises error for empty documents list."""
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            await async_client.aadd_documents(
                collection_name="test_collection", documents=[]
            )

    # ------------------------------------------------------------------
    # search / asearch
    # ------------------------------------------------------------------

    def test_search(self, client, mock_chromadb_client) -> None:
        """Test that search queries the collection correctly."""
        mock_collection = Mock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
        mock_collection.query.return_value = {
            "ids": [["doc1", "doc2"]],
            "documents": [["Document 1", "Document 2"]],
            "metadatas": [[{"source": "test1"}, {"source": "test2"}]],
            "distances": [[0.1, 0.3]],
        }

        results = client.search(collection_name="test_collection", query="test query")

        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )
        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        assert len(results) == 2
        assert results[0]["id"] == "doc1"
        assert results[0]["content"] == "Document 1"
        assert results[0]["metadata"] == {"source": "test1"}
        # The score is a float derived from the 0.1 distance above; compare
        # approximately so the test does not hinge on exact float rounding.
        assert results[0]["score"] == pytest.approx(0.95)

    def test_search_with_optional_params(self, client, mock_chromadb_client) -> None:
        """Test search with optional parameters."""
        mock_collection = Mock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
        mock_collection.query.return_value = {
            "ids": [["doc1", "doc2", "doc3"]],
            "documents": [["Document 1", "Document 2", "Document 3"]],
            "metadatas": [
                [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
            ],
            "distances": [[0.1, 0.3, 1.5]],  # Last one will be filtered by threshold
        }

        results = client.search(
            collection_name="test_collection",
            query="test query",
            limit=5,
            metadata_filter={"source": "test"},
            score_threshold=0.7,
        )

        # metadata_filter is forwarded to the collection as ``where``.
        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where={"source": "test"},
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        # Only 2 results should pass the score threshold
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_asearch(self, async_client, mock_async_chromadb_client) -> None:
        """Test that asearch queries the collection correctly."""
        mock_collection = AsyncMock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )
        mock_collection.query = AsyncMock(
            return_value={
                "ids": [["doc1", "doc2"]],
                "documents": [["Document 1", "Document 2"]],
                "metadatas": [[{"source": "test1"}, {"source": "test2"}]],
                "distances": [[0.1, 0.3]],
            }
        )

        results = await async_client.asearch(
            collection_name="test_collection", query="test query"
        )

        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )
        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        assert len(results) == 2
        assert results[0]["id"] == "doc1"
        assert results[0]["content"] == "Document 1"
        assert results[0]["metadata"] == {"source": "test1"}
        # The score is a float derived from the 0.1 distance above; compare
        # approximately so the test does not hinge on exact float rounding.
        assert results[0]["score"] == pytest.approx(0.95)

    @pytest.mark.asyncio
    async def test_asearch_with_optional_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test asearch with optional parameters."""
        mock_collection = AsyncMock()
        mock_collection.metadata = {"hnsw:space": "cosine"}
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )
        mock_collection.query = AsyncMock(
            return_value={
                "ids": [["doc1", "doc2", "doc3"]],
                "documents": [["Document 1", "Document 2", "Document 3"]],
                "metadatas": [
                    [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
                ],
                "distances": [
                    [0.1, 0.3, 1.5]
                ],  # Last one will be filtered by threshold
            }
        )

        results = await async_client.asearch(
            collection_name="test_collection",
            query="test query",
            limit=5,
            metadata_filter={"source": "test"},
            score_threshold=0.7,
        )

        mock_collection.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where={"source": "test"},
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        # Only 2 results should pass the score threshold
        assert len(results) == 2

    # ------------------------------------------------------------------
    # delete_collection / reset
    # ------------------------------------------------------------------

    def test_delete_collection(self, client, mock_chromadb_client) -> None:
        """Test that delete_collection calls the underlying client correctly."""
        client.delete_collection(collection_name="test_collection")

        mock_chromadb_client.delete_collection.assert_called_once_with(
            name="test_collection"
        )

    @pytest.mark.asyncio
    async def test_adelete_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test that adelete_collection calls the underlying client correctly."""
        mock_async_chromadb_client.delete_collection = AsyncMock(return_value=None)

        await async_client.adelete_collection(collection_name="test_collection")

        mock_async_chromadb_client.delete_collection.assert_called_once_with(
            name="test_collection"
        )

    def test_reset(self, client, mock_chromadb_client) -> None:
        """Test that reset calls the underlying client correctly."""
        mock_chromadb_client.reset.return_value = True

        client.reset()

        mock_chromadb_client.reset.assert_called_once_with()

    @pytest.mark.asyncio
    async def test_areset(self, async_client, mock_async_chromadb_client) -> None:
        """Test that areset calls the underlying client correctly."""
        mock_async_chromadb_client.reset = AsyncMock(return_value=True)

        await async_client.areset()

        mock_async_chromadb_client.reset.assert_called_once_with()

    # ------------------------------------------------------------------
    # batching
    # ------------------------------------------------------------------

    def test_add_documents_with_batch_size(
        self, client_with_batch_size, mock_chromadb_client
    ) -> None:
        """Test add_documents with batch size splits documents into batches."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
            {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
            {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
            {"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
            {"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
        ]

        client_with_batch_size.add_documents(
            collection_name="test_collection", documents=documents
        )

        # Five documents with a default batch size of 2 -> batches of 2, 2, 1.
        assert mock_collection.upsert.call_count == 3

        first_call = mock_collection.upsert.call_args_list[0]
        assert first_call.kwargs["ids"] == ["id1", "id2"]
        assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
        assert first_call.kwargs["metadatas"] == [
            {"source": "test1"},
            {"source": "test2"},
        ]

        second_call = mock_collection.upsert.call_args_list[1]
        assert second_call.kwargs["ids"] == ["id3", "id4"]
        assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
        assert second_call.kwargs["metadatas"] == [
            {"source": "test3"},
            {"source": "test4"},
        ]

        third_call = mock_collection.upsert.call_args_list[2]
        assert third_call.kwargs["ids"] == ["id5"]
        assert third_call.kwargs["documents"] == ["Document 5"]
        assert third_call.kwargs["metadatas"] == [{"source": "test5"}]

    def test_add_documents_with_explicit_batch_size(
        self, client, mock_chromadb_client
    ) -> None:
        """Test add_documents with explicitly provided batch size."""
        mock_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = mock_collection

        documents: list[BaseRecord] = [
            {"doc_id": "id1", "content": "Document 1"},
            {"doc_id": "id2", "content": "Document 2"},
            {"doc_id": "id3", "content": "Document 3"},
        ]

        client.add_documents(
            collection_name="test_collection", documents=documents, batch_size=1
        )

        # batch_size=1 means one upsert per document, in input order.
        assert mock_collection.upsert.call_count == 3
        for i, call in enumerate(mock_collection.upsert.call_args_list):
            assert len(call.kwargs["ids"]) == 1
            assert call.kwargs["ids"] == [f"id{i + 1}"]

    @pytest.mark.asyncio
    async def test_aadd_documents_with_batch_size(
        self, async_client_with_batch_size, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with batch size splits documents into batches."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
            {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
            {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
        ]

        await async_client_with_batch_size.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        # Three documents with a default batch size of 2 -> batches of 2, 1.
        assert mock_collection.upsert.call_count == 2

        first_call = mock_collection.upsert.call_args_list[0]
        assert first_call.kwargs["ids"] == ["id1", "id2"]
        assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]

        second_call = mock_collection.upsert.call_args_list[1]
        assert second_call.kwargs["ids"] == ["id3"]
        assert second_call.kwargs["documents"] == ["Document 3"]

    @pytest.mark.asyncio
    async def test_aadd_documents_with_explicit_batch_size(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """Test aadd_documents with explicitly provided batch size."""
        mock_collection = AsyncMock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        documents: list[BaseRecord] = [
            {"doc_id": "id1", "content": "Document 1"},
            {"doc_id": "id2", "content": "Document 2"},
            {"doc_id": "id3", "content": "Document 3"},
            {"doc_id": "id4", "content": "Document 4"},
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents, batch_size=3
        )

        # Four documents with batch_size=3 -> batches of 3, 1.
        assert mock_collection.upsert.call_count == 2

        first_call = mock_collection.upsert.call_args_list[0]
        assert len(first_call.kwargs["ids"]) == 3

        second_call = mock_collection.upsert.call_args_list[1]
        assert len(second_call.kwargs["ids"]) == 1

    def test_client_default_batch_size_initialization(self) -> None:
        """Test that client initializes with correct default batch size."""
        mock_client = Mock()
        mock_embedding = Mock()

        client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
        assert client.default_batch_size == 100

        custom_client = ChromaDBClient(
            client=mock_client, embedding_function=mock_embedding, default_batch_size=50
        )
        assert custom_client.default_batch_size == 50
|
||||
302
tests/rag/chromadb/test_utils.py
Normal file
302
tests/rag/chromadb/test_utils.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Tests for ChromaDB utility functions."""
|
||||
|
||||
from crewai.rag.chromadb.types import PreparedDocuments
|
||||
from crewai.rag.chromadb.utils import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
_create_batch_slice,
|
||||
_is_ipv4_pattern,
|
||||
_prepare_documents_for_chromadb,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
class TestChromaDBUtils:
    """Test suite for ChromaDB utility functions.

    Most tests check the invariants asserted on the output of
    ``_sanitize_collection_name``: length within
    ``MIN_COLLECTION_LENGTH``..``MAX_COLLECTION_LENGTH``, alphanumeric first
    and last characters, and only alphanumerics, ``_`` or ``-`` in between.
    """

    def test_sanitize_collection_name_long_name(self) -> None:
        """Test sanitizing a very long collection name."""
        long_name = (
            "This is an extremely long role name that will definitely exceed "
            "the ChromaDB collection name limit of 63 characters and cause "
            "an error when used as a collection name"
        )
        sanitized = _sanitize_collection_name(long_name)
        assert len(sanitized) <= MAX_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ("_", "-") for c in sanitized)

    def test_sanitize_collection_name_special_chars(self) -> None:
        """Test sanitizing a name with special characters."""
        special_chars = "Agent@123!#$%^&*()"
        sanitized = _sanitize_collection_name(special_chars)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ("_", "-") for c in sanitized)

    def test_sanitize_collection_name_short_name(self) -> None:
        """Test sanitizing a very short name."""
        short_name = "A"
        sanitized = _sanitize_collection_name(short_name)
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_bad_ends(self) -> None:
        """Test sanitizing a name with non-alphanumeric start/end."""
        bad_ends = "_Agent_"
        sanitized = _sanitize_collection_name(bad_ends)
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_none(self) -> None:
        """Test sanitizing a None value."""
        sanitized = _sanitize_collection_name(None)
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_ipv4_pattern(self) -> None:
        """Test sanitizing an IPv4 address."""
        ipv4 = "192.168.1.1"
        sanitized = _sanitize_collection_name(ipv4)
        # IPv4-looking names are expected to gain an "ip_" prefix.
        assert sanitized.startswith("ip_")
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
        assert all(c.isalnum() or c in ("_", "-") for c in sanitized)

    def test_is_ipv4_pattern(self) -> None:
        """Test IPv4 pattern detection."""
        assert _is_ipv4_pattern("192.168.1.1") is True
        assert _is_ipv4_pattern("not.an.ip.address") is False

    def test_sanitize_collection_name_properties(self) -> None:
        """Test that sanitized collection names always meet ChromaDB requirements."""
        test_cases: list[str] = [
            "A" * 100,  # Very long name
            "_start_with_underscore",
            "end_with_underscore_",
            "contains@special#characters",
            "192.168.1.1",  # IPv4 address
            "a" * 2,  # Too short
        ]
        for test_case in test_cases:
            sanitized = _sanitize_collection_name(test_case)
            assert len(sanitized) >= MIN_COLLECTION_LENGTH
            assert len(sanitized) <= MAX_COLLECTION_LENGTH
            assert sanitized[0].isalnum()
            assert sanitized[-1].isalnum()

    def test_sanitize_collection_name_empty_string(self) -> None:
        """Test sanitizing an empty string."""
        sanitized = _sanitize_collection_name("")
        assert sanitized == "default_collection"

    def test_sanitize_collection_name_whitespace_only(self) -> None:
        """Test sanitizing a string with only whitespace."""
        # NOTE(review): the input literal may have lost whitespace when this
        # file was extracted; the expected "a__z" suggests the original input
        # was longer than a single space. Confirm against the upstream test.
        sanitized = _sanitize_collection_name(" ")
        assert (
            sanitized == "a__z"
        )  # Spaces become underscores, padded to meet requirements
        assert len(sanitized) >= MIN_COLLECTION_LENGTH
        assert sanitized[0].isalnum()
        assert sanitized[-1].isalnum()
|
||||
|
||||
|
||||
class TestPrepareDocumentsForChromaDB:
    """Checks for `_prepare_documents_for_chromadb`, which splits records into ids, texts and metadatas."""

    def test_prepare_documents_with_doc_ids(self) -> None:
        """Records that carry a ``doc_id`` keep it, along with their content and metadata."""
        records: list[BaseRecord] = [
            {"doc_id": "id1", "content": "First document", "metadata": {"source": "test1"}},
            {"doc_id": "id2", "content": "Second document", "metadata": {"source": "test2"}},
        ]

        prepared = _prepare_documents_for_chromadb(records)

        assert prepared.ids == ["id1", "id2"]
        assert prepared.texts == ["First document", "Second document"]
        assert prepared.metadatas == [{"source": "test1"}, {"source": "test2"}]

    def test_prepare_documents_generate_ids(self) -> None:
        """Records without a ``doc_id`` receive a generated 64-character id each."""
        records: list[BaseRecord] = [
            {"content": "Test content", "metadata": {"key": "value"}},
            {"content": "Another test"},
        ]

        prepared = _prepare_documents_for_chromadb(records)

        assert len(prepared.ids) == 2
        assert all(len(generated) == 64 for generated in prepared.ids)
        assert prepared.texts == ["Test content", "Another test"]
        assert prepared.metadatas == [{"key": "value"}, {}]

    def test_prepare_documents_with_list_metadata(self) -> None:
        """List-valued metadata collapses to its first entry, or to ``{}`` when the list is empty."""
        records: list[BaseRecord] = [
            {"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]},
            {"content": "Test2", "metadata": []},
        ]

        prepared = _prepare_documents_for_chromadb(records)

        assert prepared.metadatas == [{"first": "item"}, {}]

    def test_prepare_documents_no_metadata(self) -> None:
        """Absent metadata and ``None`` metadata both become an empty dict."""
        records: list[BaseRecord] = [
            {"content": "Document 1"},
            {"content": "Document 2", "metadata": None},
        ]

        prepared = _prepare_documents_for_chromadb(records)

        assert prepared.metadatas == [{}, {}]

    def test_prepare_documents_hash_consistency(self) -> None:
        """Two separately built but equal records produce the same generated ids."""
        first_batch: list[BaseRecord] = [
            {"content": "Same content", "metadata": {"key": "value"}}
        ]
        second_batch: list[BaseRecord] = [
            {"content": "Same content", "metadata": {"key": "value"}}
        ]

        first_prepared = _prepare_documents_for_chromadb(first_batch)
        second_prepared = _prepare_documents_for_chromadb(second_batch)

        assert first_prepared.ids == second_prepared.ids
|
||||
|
||||
|
||||
class TestCreateBatchSlice:
    """Checks for `_create_batch_slice`, which cuts one batch out of a `PreparedDocuments`."""

    def test_create_batch_slice_normal(self) -> None:
        """A slice from the middle returns the matching ids, texts and metadatas."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3", "id4", "id5"],
            texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=1, batch_size=3)

        assert ids == ["id2", "id3", "id4"]
        assert texts == ["doc2", "doc3", "doc4"]
        assert metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]

    def test_create_batch_slice_at_end(self) -> None:
        """A batch size that overshoots the data returns only the remaining items."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=2, batch_size=5)

        assert ids == ["id3"]
        assert texts == ["doc3"]
        assert metadatas == [{"c": 3}]

    def test_create_batch_slice_empty_batch(self) -> None:
        """Starting past the last item yields three empty lists."""
        source = PreparedDocuments(
            ids=["id1", "id2"],
            texts=["doc1", "doc2"],
            metadatas=[{"a": 1}, {"b": 2}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=5, batch_size=3)

        assert ids == []
        assert texts == []
        assert metadatas == []

    def test_create_batch_slice_no_metadatas(self) -> None:
        """With an empty metadatas list the metadata part of the batch is ``None``."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=0, batch_size=2)

        assert ids == ["id1", "id2"]
        assert texts == ["doc1", "doc2"]
        assert metadatas is None

    def test_create_batch_slice_all_empty_metadatas(self) -> None:
        """When every metadata dict in the batch is empty the metadata part is ``None``."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{}, {}, {}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=0, batch_size=3)

        assert ids == ["id1", "id2", "id3"]
        assert texts == ["doc1", "doc2", "doc3"]
        assert metadatas is None

    def test_create_batch_slice_some_empty_metadatas(self) -> None:
        """A mix of empty and non-empty metadata dicts is returned unchanged."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=0, batch_size=3)

        assert ids == ["id1", "id2", "id3"]
        assert texts == ["doc1", "doc2", "doc3"]
        assert metadatas == [{"a": 1}, {}, {"c": 3}]

    def test_create_batch_slice_zero_start_index(self) -> None:
        """A slice starting at index 0 returns the leading items."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3", "id4"],
            texts=["doc1", "doc2", "doc3", "doc4"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=0, batch_size=2)

        assert ids == ["id1", "id2"]
        assert texts == ["doc1", "doc2"]
        assert metadatas == [{"a": 1}, {"b": 2}]

    def test_create_batch_slice_single_item(self) -> None:
        """A batch size of 1 returns exactly the item at the start index."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(source, start_index=1, batch_size=1)

        assert ids == ["id2"]
        assert texts == ["doc2"]
        assert metadatas == [{"b": 2}]
|
||||
36
tests/rag/config/test_factory.py
Normal file
36
tests/rag/config/test_factory.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Tests for RAG config factory."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
def test_create_client_chromadb():
    """`create_client` loads the chromadb factory through `require` and returns what it builds."""
    config = Mock(provider="chromadb")
    expected_client = Mock()
    factory_module = Mock()
    factory_module.create_client.return_value = expected_client

    # Replace the lazy-import helper so no real chromadb module is needed.
    with patch(
        "crewai.rag.factory.require", return_value=factory_module
    ) as require_mock:
        created = create_client(config)

        assert created == expected_client
        require_mock.assert_called_once_with(
            "crewai.rag.chromadb.factory", purpose="The 'chromadb' provider"
        )
        factory_module.create_client.assert_called_once_with(config)
|
||||
|
||||
|
||||
def test_create_client_unsupported_provider():
    """A config whose provider is not recognised makes `create_client` raise ValueError."""
    config = Mock(provider="unsupported")
    expected_message = "Unsupported provider: unsupported"

    with pytest.raises(ValueError, match=expected_message):
        create_client(config)
|
||||
22
tests/rag/config/test_optional_imports.py
Normal file
22
tests/rag/config/test_optional_imports.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Tests for optional imports."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
|
||||
|
||||
|
||||
def test_missing_provider_raises_runtime_error():
    """Instantiating `_MissingProvider` fails with a RuntimeError naming the placeholder provider."""
    expected_message = "provider '__missing__' requested but not installed"

    with pytest.raises(RuntimeError, match=expected_message):
        _MissingProvider()
|
||||
|
||||
|
||||
def test_missing_chromadb_config_raises_runtime_error():
    """Instantiating `MissingChromaDBConfig` fails with a RuntimeError naming 'chromadb'."""
    expected_message = "provider 'chromadb' requested but not installed"

    with pytest.raises(RuntimeError, match=expected_message):
        MissingChromaDBConfig()
|
||||
244
tests/rag/embeddings/test_embedding_factory.py
Normal file
244
tests/rag/embeddings/test_embedding_factory.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
|
||||
class TestEmbeddingFactory:
    """Checks that `build_embedder` resolves each provider name to the expected provider class."""

    @staticmethod
    def _stub_provider(mock_import: MagicMock) -> MagicMock:
        """Make the patched importer return a stand-in provider class, and return that class."""
        provider_cls = MagicMock()
        provider_obj = MagicMock()
        provider_obj.embedding_callable.return_value = MagicMock()
        provider_cls.return_value = provider_obj
        mock_import.return_value = provider_cls
        return provider_cls

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_openai(self, mock_import):
        """The 'openai' provider imports OpenAIProvider and forwards the config as keyword arguments."""
        provider_cls = self._stub_provider(mock_import)
        spec = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model_name": "text-embedding-3-small",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
        )
        provider_cls.assert_called_once()

        passed = provider_cls.call_args.kwargs
        assert passed["api_key"] == "test-key"
        assert passed["model_name"] == "text-embedding-3-small"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_azure(self, mock_import):
        """The 'azure' provider imports AzureProvider and forwards the Azure settings."""
        provider_cls = self._stub_provider(mock_import)
        spec = {
            "provider": "azure",
            "config": {
                "api_key": "test-azure-key",
                "api_base": "https://test.openai.azure.com/",
                "api_type": "azure",
                "api_version": "2023-05-15",
                "model_name": "text-embedding-3-small",
                "deployment_id": "test-deployment",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )

        passed = provider_cls.call_args.kwargs
        assert passed["api_key"] == "test-azure-key"
        assert passed["api_base"] == "https://test.openai.azure.com/"
        assert passed["api_type"] == "azure"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_ollama(self, mock_import):
        """The 'ollama' provider imports OllamaProvider."""
        self._stub_provider(mock_import)
        spec = {
            "provider": "ollama",
            "config": {
                "model_name": "nomic-embed-text",
                "url": "http://localhost:11434",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
        )

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_cohere(self, mock_import):
        """The 'cohere' provider imports CohereProvider."""
        self._stub_provider(mock_import)
        spec = {
            "provider": "cohere",
            "config": {
                "api_key": "cohere-key",
                "model_name": "embed-english-v3.0",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
        )

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_voyageai(self, mock_import):
        """The 'voyageai' provider imports VoyageAIProvider."""
        self._stub_provider(mock_import)
        spec = {
            "provider": "voyageai",
            "config": {
                "api_key": "voyage-key",
                "model": "voyage-2",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
        )

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_watsonx(self, mock_import):
        """The 'watsonx' provider imports WatsonXProvider."""
        self._stub_provider(mock_import)
        spec = {
            "provider": "watsonx",
            "config": {
                "model_id": "ibm/slate-125m-english-rtrvr",
                "api_key": "watsonx-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "test-project",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
        )

    def test_build_embedder_unknown_provider(self):
        """An unrecognised provider name raises ValueError mentioning that name."""
        spec = {"provider": "unknown-provider", "config": {}}

        with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
            build_embedder(spec)

    def test_build_embedder_missing_provider(self):
        """A spec without a 'provider' key raises KeyError."""
        spec = {"config": {"api_key": "test-key"}}

        with pytest.raises(KeyError):
            build_embedder(spec)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_import_error(self, mock_import):
        """An ImportError from the importer surfaces as an ImportError naming the provider."""
        mock_import.side_effect = ImportError("Module not found")
        spec = {"provider": "openai", "config": {"api_key": "test-key"}}

        with pytest.raises(ImportError, match="Failed to import provider openai"):
            build_embedder(spec)

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_build_embedder_custom_provider(self, mock_import):
        """The 'custom' provider imports CustomProvider and receives the supplied callable."""
        user_callable = MagicMock()
        provider_obj = MagicMock()
        # Here the callable is set as the attribute itself, not as a return value.
        provider_obj.embedding_callable = user_callable
        provider_cls = MagicMock()
        provider_cls.return_value = provider_obj
        mock_import.return_value = provider_cls

        spec = {
            "provider": "custom",
            "config": {"embedding_callable": user_callable},
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
        )

        passed = provider_cls.call_args.kwargs
        assert passed["embedding_callable"] == user_callable

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    @patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
    def test_build_embedder_with_provider_instance(
        self, mock_build_from_provider, mock_import
    ):
        """A provider instance is handed to `build_embedder_from_provider` without any import lookup."""
        from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider

        provider = MagicMock(spec=BaseEmbeddingsProvider)
        expected_function = MagicMock()
        mock_build_from_provider.return_value = expected_function

        built = build_embedder(provider)

        mock_build_from_provider.assert_called_once_with(provider)
        assert built == expected_function
        mock_import.assert_not_called()
|
||||
122
tests/rag/embeddings/test_factory_azure.py
Normal file
122
tests/rag/embeddings/test_factory_azure.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Test Azure embedder configuration with factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
|
||||
class TestAzureEmbedderFactory:
    """Checks how `build_embedder` handles Azure (and, for contrast, OpenAI) nested configs."""

    @staticmethod
    def _stub_provider(mock_import: MagicMock) -> tuple[MagicMock, MagicMock]:
        """Wire the patched importer to a stand-in provider class.

        Returns the provider class mock and the mock that the provider
        instance's ``embedding_callable()`` returns.
        """
        embedding_function = MagicMock()
        provider_obj = MagicMock()
        provider_obj.embedding_callable.return_value = embedding_function
        provider_cls = MagicMock()
        provider_cls.return_value = provider_obj
        mock_import.return_value = provider_cls
        return provider_cls, embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_azure_with_nested_config(self, mock_import):
        """Every key in the nested Azure config reaches the provider class as a keyword argument."""
        provider_cls, embedding_function = self._stub_provider(mock_import)
        spec = {
            "provider": "azure",
            "config": {
                "api_key": "test-azure-key",
                "api_base": "https://test.openai.azure.com/",
                "api_type": "azure",
                "api_version": "2023-05-15",
                "model_name": "text-embedding-3-small",
                "deployment_id": "test-deployment",
            },
        }

        built = build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )

        passed = provider_cls.call_args.kwargs
        assert passed["api_key"] == "test-azure-key"
        assert passed["api_base"] == "https://test.openai.azure.com/"
        assert passed["api_type"] == "azure"
        assert passed["api_version"] == "2023-05-15"
        assert passed["model_name"] == "text-embedding-3-small"
        assert passed["deployment_id"] == "test-deployment"

        assert built == embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_regular_openai_with_nested_config(self, mock_import):
        """A nested OpenAI config resolves to OpenAIProvider and forwards its keys."""
        provider_cls, embedding_function = self._stub_provider(mock_import)
        spec = {
            "provider": "openai",
            "config": {"api_key": "test-openai-key", "model": "text-embedding-3-large"},
        }

        built = build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
        )

        passed = provider_cls.call_args.kwargs
        assert passed["api_key"] == "test-openai-key"
        assert passed["model"] == "text-embedding-3-large"

        assert built == embedding_function

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_azure_provider_with_minimal_config(self, mock_import):
        """An Azure config holding only the key and base URL is forwarded as given."""
        provider_cls, _ = self._stub_provider(mock_import)
        spec = {
            "provider": "azure",
            "config": {
                "api_key": "test-key",
                "api_base": "https://test.openai.azure.com/",
            },
        }

        build_embedder(spec)

        mock_import.assert_called_once_with(
            "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
        )

        passed = provider_cls.call_args.kwargs
        assert passed["api_key"] == "test-key"
        assert passed["api_base"] == "https://test.openai.azure.com/"

    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
    def test_azure_import_error(self, mock_import):
        """An ImportError while loading the Azure provider is reported with the provider name."""
        mock_import.side_effect = ImportError("Failed to import Azure provider")
        spec = {
            "provider": "azure",
            "config": {"api_key": "test-key"},
        }

        with pytest.raises(ImportError) as caught:
            build_embedder(spec)

        assert "Failed to import provider azure" in str(caught.value)
|
||||
792
tests/rag/qdrant/test_client.py
Normal file
792
tests/rag/qdrant/test_client.py
Normal file
@@ -0,0 +1,792 @@
|
||||
"""Tests for QdrantClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client import QdrantClient as SyncQdrantClient
|
||||
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
@pytest.fixture
def mock_qdrant_client():
    """Provide a Mock restricted to the synchronous Qdrant client's interface."""
    sync_stub = Mock(spec=SyncQdrantClient)
    return sync_stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_async_qdrant_client():
    """Provide a Mock restricted to the asynchronous Qdrant client's interface."""
    async_stub = Mock(spec=AsyncQdrantClient)
    return async_stub
|
||||
|
||||
|
||||
@pytest.fixture
def client(mock_qdrant_client) -> QdrantClient:
    """Build a QdrantClient around the sync mock with an embedder that returns a fixed vector."""
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return QdrantClient(client=mock_qdrant_client, embedding_function=embedder)
|
||||
|
||||
|
||||
@pytest.fixture
def async_client(mock_async_qdrant_client) -> QdrantClient:
    """Build a QdrantClient around the async mock with an embedder that returns a fixed vector."""
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return QdrantClient(client=mock_async_qdrant_client, embedding_function=embedder)
|
||||
|
||||
|
||||
class TestQdrantClient:
|
||||
"""Test suite for QdrantClient."""
|
||||
|
||||
def test_create_collection(self, client, mock_qdrant_client):
    """`create_collection` checks for existence, then creates the collection with a vectors config."""
    mock_qdrant_client.collection_exists.return_value = False

    client.create_collection(collection_name="test_collection")

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_called_once()
    passed = mock_qdrant_client.create_collection.call_args.kwargs
    assert passed["collection_name"] == "test_collection"
    assert passed["vectors_config"] is not None
|
||||
|
||||
def test_create_collection_already_exists(self, client, mock_qdrant_client):
    """`create_collection` raises ValueError when the collection is already present."""
    mock_qdrant_client.collection_exists.return_value = True
    expected_message = "Collection 'test_collection' already exists"

    with pytest.raises(ValueError, match=expected_message):
        client.create_collection(collection_name="test_collection")
|
||||
|
||||
def test_create_collection_wrong_client_type(self, mock_async_qdrant_client):
    """Calling the sync `create_collection` on an async-backed client raises ClientMethodMismatchError."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    expected_pattern = r"Method create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
async def test_acreate_collection(self, async_client, mock_async_qdrant_client):
    """`acreate_collection` checks for existence, then creates the collection with a vectors config."""
    # The awaited client methods must be AsyncMock so they can be awaited.
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    mock_async_qdrant_client.create_collection = AsyncMock()

    await async_client.acreate_collection(collection_name="test_collection")

    mock_async_qdrant_client.collection_exists.assert_called_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.create_collection.assert_called_once()
    passed = mock_async_qdrant_client.create_collection.call_args.kwargs
    assert passed["collection_name"] == "test_collection"
    assert passed["vectors_config"] is not None
|
||||
|
||||
@pytest.mark.asyncio
async def test_acreate_collection_already_exists(
    self, async_client, mock_async_qdrant_client
):
    """`acreate_collection` raises ValueError when the collection is already present."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    expected_message = "Collection 'test_collection' already exists"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
async def test_acreate_collection_wrong_client_type(self, mock_qdrant_client):
    """Calling the async `acreate_collection` on a sync-backed client raises ClientMethodMismatchError."""
    mismatched = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    expected_pattern = r"Method acreate_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.acreate_collection(collection_name="test_collection")
|
||||
|
||||
def test_get_or_create_collection_existing(self, client, mock_qdrant_client):
    """An existing collection is fetched and returned without any create call."""
    existing_info = Mock()
    mock_qdrant_client.collection_exists.return_value = True
    mock_qdrant_client.get_collection.return_value = existing_info

    fetched = client.get_or_create_collection(collection_name="test_collection")

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_not_called()
    assert fetched == existing_info
|
||||
|
||||
def test_get_or_create_collection_new(self, client, mock_qdrant_client):
    """A missing collection is created first and then fetched and returned."""
    created_info = Mock()
    mock_qdrant_client.collection_exists.return_value = False
    mock_qdrant_client.get_collection.return_value = created_info

    fetched = client.get_or_create_collection(collection_name="test_collection")

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_called_once()
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    assert fetched == created_info
|
||||
|
||||
def test_get_or_create_collection_wrong_client_type(self, mock_async_qdrant_client):
    """get_or_create_collection on an async-backed wrapper raises ClientMethodMismatchError."""
    async_backed = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    expected_message = r"Method get_or_create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        async_backed.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
async def test_aget_or_create_collection_existing(
    self, async_client, mock_async_qdrant_client
):
    """The async variant returns an existing collection without creating one."""
    existing_info = Mock()
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.get_collection = AsyncMock(return_value=existing_info)

    returned = await async_client.aget_or_create_collection(
        collection_name="test_collection"
    )

    # One existence check, one fetch, and no creation call.
    mock_async_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_async_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_async_qdrant_client.create_collection.assert_not_called()
    assert returned == existing_info
|
||||
|
||||
@pytest.mark.asyncio
async def test_aget_or_create_collection_new(
    self, async_client, mock_async_qdrant_client
):
    """The async variant creates a missing collection and returns its info."""
    created_info = Mock()
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    mock_async_qdrant_client.create_collection = AsyncMock()
    mock_async_qdrant_client.get_collection = AsyncMock(return_value=created_info)

    returned = await async_client.aget_or_create_collection(
        collection_name="test_collection"
    )

    mock_async_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_async_qdrant_client.create_collection.assert_called_once()
    mock_async_qdrant_client.get_collection.assert_called_once_with("test_collection")
    assert returned == created_info
|
||||
|
||||
@pytest.mark.asyncio
async def test_aget_or_create_collection_wrong_client_type(self, mock_qdrant_client):
    """aget_or_create_collection on a sync-backed wrapper raises ClientMethodMismatchError."""
    sync_backed = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    expected_message = r"Method aget_or_create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        await sync_backed.aget_or_create_collection(collection_name="test_collection")
|
||||
|
||||
def test_add_documents(self, client, mock_qdrant_client):
    """add_documents embeds the content and upserts one point carrying the payload."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]

    client.add_documents(collection_name="test_collection", documents=records)

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    client.embedding_function.assert_called_once_with("Test document")
    mock_qdrant_client.upsert.assert_called_once()

    # Inspect the keyword arguments of the single upsert call.
    upsert_kwargs = mock_qdrant_client.upsert.call_args.kwargs
    assert upsert_kwargs["collection_name"] == "test_collection"
    assert len(upsert_kwargs["points"]) == 1
    stored = upsert_kwargs["points"][0]
    assert stored.vector == [0.1, 0.2, 0.3]
    assert stored.payload["content"] == "Test document"
    assert stored.payload["source"] == "test"
|
||||
|
||||
def test_add_documents_with_doc_id(self, client, mock_qdrant_client):
    """A caller-supplied doc_id becomes the id of the upserted point."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]
    records: list[BaseRecord] = [
        {
            "doc_id": "custom-id-123",
            "content": "Test document",
            "metadata": {"source": "test"},
        }
    ]

    client.add_documents(collection_name="test_collection", documents=records)

    upsert_kwargs = mock_qdrant_client.upsert.call_args.kwargs
    stored = upsert_kwargs["points"][0]
    assert stored.id == "custom-id-123"
|
||||
|
||||
def test_add_documents_empty_list(self, client, mock_qdrant_client):
    """add_documents rejects an empty documents list with ValueError."""
    no_records: list[BaseRecord] = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_collection", documents=no_records)
|
||||
|
||||
def test_add_documents_collection_not_exists(self, client, mock_qdrant_client):
    """add_documents raises ValueError when the target collection is missing."""
    mock_qdrant_client.collection_exists.return_value = False
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_collection", documents=records)
|
||||
|
||||
def test_add_documents_wrong_client_type(self, mock_async_qdrant_client):
    """add_documents on an async-backed wrapper raises ClientMethodMismatchError."""
    async_backed = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]
    expected_message = r"Method add_documents\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        async_backed.add_documents(collection_name="test_collection", documents=records)
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents(self, async_client, mock_async_qdrant_client):
    """aadd_documents embeds the content and upserts one point carrying the payload."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.upsert = AsyncMock()
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=records
    )

    mock_async_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    async_client.embedding_function.assert_called_once_with("Test document")
    mock_async_qdrant_client.upsert.assert_called_once()

    # Inspect the keyword arguments of the single upsert call.
    upsert_kwargs = mock_async_qdrant_client.upsert.call_args.kwargs
    assert upsert_kwargs["collection_name"] == "test_collection"
    assert len(upsert_kwargs["points"]) == 1
    stored = upsert_kwargs["points"][0]
    assert stored.vector == [0.1, 0.2, 0.3]
    assert stored.payload["content"] == "Test document"
    assert stored.payload["source"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents_with_doc_id(
    self, async_client, mock_async_qdrant_client
):
    """A caller-supplied doc_id becomes the id of the point upserted by aadd_documents."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.upsert = AsyncMock()
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
    records: list[BaseRecord] = [
        {
            "doc_id": "custom-id-123",
            "content": "Test document",
            "metadata": {"source": "test"},
        }
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=records
    )

    upsert_kwargs = mock_async_qdrant_client.upsert.call_args.kwargs
    stored = upsert_kwargs["points"][0]
    assert stored.id == "custom-id-123"
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents_empty_list(self, async_client, mock_async_qdrant_client):
    """aadd_documents rejects an empty documents list with ValueError."""
    no_records: list[BaseRecord] = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.aadd_documents(
            collection_name="test_collection", documents=no_records
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """aadd_documents raises ValueError when the target collection is missing."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.aadd_documents(
            collection_name="test_collection", documents=records
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_aadd_documents_wrong_client_type(self, mock_qdrant_client):
    """aadd_documents on a sync-backed wrapper raises ClientMethodMismatchError."""
    sync_backed = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    records: list[BaseRecord] = [
        {"content": "Test document", "metadata": {"source": "test"}}
    ]
    expected_message = r"Method aadd_documents\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        await sync_backed.aadd_documents(
            collection_name="test_collection", documents=records
        )
|
||||
|
||||
def test_search(self, client, mock_qdrant_client):
    """Test that search embeds the query, queries Qdrant and maps the hits.

    The mocked backend returns one scored point. The test verifies the
    arguments forwarded to ``query_points`` (including the default limit of 5)
    and the shape of the single result dictionary.
    """
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    mock_point = Mock()
    mock_point.id = "doc-123"
    mock_point.payload = {"content": "Test content", "source": "test"}
    mock_point.score = 0.95

    mock_response = Mock()
    mock_response.points = [mock_point]
    mock_qdrant_client.query_points.return_value = mock_response

    results = client.search(collection_name="test_collection", query="test query")

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    client.embedding_function.assert_called_once_with("test query")
    mock_qdrant_client.query_points.assert_called_once()

    call_args = mock_qdrant_client.query_points.call_args
    assert call_args.kwargs["collection_name"] == "test_collection"
    assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
    assert call_args.kwargs["limit"] == 5
    assert call_args.kwargs["with_payload"] is True
    assert call_args.kwargs["with_vectors"] is False

    assert len(results) == 1
    assert results[0]["id"] == "doc-123"
    assert results[0]["content"] == "Test content"
    assert results[0]["metadata"] == {"source": "test"}
    # The returned score (0.975) differs from the raw point score (0.95), so it
    # is derived by floating-point arithmetic; compare with a tolerance rather
    # than exact equality.
    assert results[0]["score"] == pytest.approx(0.975)
|
||||
|
||||
def test_search_with_filters(self, client, mock_qdrant_client):
    """Each metadata_filter entry becomes one `must` condition of the query filter."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]
    empty_response = Mock()
    empty_response.points = []
    mock_qdrant_client.query_points.return_value = empty_response

    client.search(
        collection_name="test_collection",
        query="test query",
        metadata_filter={"category": "tech", "status": "published"},
    )

    query_filter = mock_qdrant_client.query_points.call_args.kwargs["query_filter"]
    assert len(query_filter.must) == 2
    # Both key/value pairs must appear among the conditions, in any order.
    for wanted_key, wanted_value in (("category", "tech"), ("status", "published")):
        assert any(
            cond.key == wanted_key and cond.match.value == wanted_value
            for cond in query_filter.must
        )
|
||||
|
||||
def test_search_with_options(self, client, mock_qdrant_client):
    """Test that search forwards an explicit limit and score_threshold.

    The limit used here is deliberately different from the default limit of 5
    that ``test_search`` observes when no limit is passed; otherwise an
    implementation that ignored the ``limit`` argument would still pass.
    """
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    mock_response = Mock()
    mock_response.points = []
    mock_qdrant_client.query_points.return_value = mock_response

    client.search(
        collection_name="test_collection",
        query="test query",
        limit=3,
        score_threshold=0.8,
    )

    call_args = mock_qdrant_client.query_points.call_args
    # Non-default value proves the caller's limit reached the backend.
    assert call_args.kwargs["limit"] == 3
    assert call_args.kwargs["score_threshold"] == 0.8
|
||||
|
||||
def test_search_collection_not_exists(self, client, mock_qdrant_client):
    """search raises ValueError when the target collection is missing."""
    mock_qdrant_client.collection_exists.return_value = False
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
def test_search_wrong_client_type(self, mock_async_qdrant_client):
    """search on an async-backed wrapper raises ClientMethodMismatchError."""
    async_backed = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    expected_message = r"Method search\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        async_backed.search(collection_name="test_collection", query="test query")
|
||||
|
||||
@pytest.mark.asyncio
async def test_asearch(self, async_client, mock_async_qdrant_client):
    """Test that asearch embeds the query, queries Qdrant and maps the hits.

    The mocked async backend returns one scored point. The test verifies the
    arguments forwarded to ``query_points`` (including the default limit of 5)
    and the shape of the single result dictionary.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

    mock_point = Mock()
    mock_point.id = "doc-123"
    mock_point.payload = {"content": "Test content", "source": "test"}
    mock_point.score = 0.95

    mock_response = Mock()
    mock_response.points = [mock_point]
    mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)

    results = await async_client.asearch(
        collection_name="test_collection", query="test query"
    )

    mock_async_qdrant_client.collection_exists.assert_called_once_with(
        "test_collection"
    )
    async_client.embedding_function.assert_called_once_with("test query")
    mock_async_qdrant_client.query_points.assert_called_once()

    call_args = mock_async_qdrant_client.query_points.call_args
    assert call_args.kwargs["collection_name"] == "test_collection"
    assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
    assert call_args.kwargs["limit"] == 5
    assert call_args.kwargs["with_payload"] is True
    assert call_args.kwargs["with_vectors"] is False

    assert len(results) == 1
    assert results[0]["id"] == "doc-123"
    assert results[0]["content"] == "Test content"
    assert results[0]["metadata"] == {"source": "test"}
    # The returned score (0.975) differs from the raw point score (0.95), so it
    # is derived by floating-point arithmetic; compare with a tolerance rather
    # than exact equality.
    assert results[0]["score"] == pytest.approx(0.975)
|
||||
|
||||
@pytest.mark.asyncio
async def test_asearch_with_filters(self, async_client, mock_async_qdrant_client):
    """Each metadata_filter entry becomes one `must` condition in asearch's query filter."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
    empty_response = Mock()
    empty_response.points = []
    mock_async_qdrant_client.query_points = AsyncMock(return_value=empty_response)

    await async_client.asearch(
        collection_name="test_collection",
        query="test query",
        metadata_filter={"category": "tech", "status": "published"},
    )

    query_filter = mock_async_qdrant_client.query_points.call_args.kwargs["query_filter"]
    assert len(query_filter.must) == 2
    # Both key/value pairs must appear among the conditions, in any order.
    for wanted_key, wanted_value in (("category", "tech"), ("status", "published")):
        assert any(
            cond.key == wanted_key and cond.match.value == wanted_value
            for cond in query_filter.must
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_asearch_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """asearch raises ValueError when the target collection is missing."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.asearch(
            collection_name="test_collection", query="test query"
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_asearch_wrong_client_type(self, mock_qdrant_client):
    """asearch on a sync-backed wrapper raises ClientMethodMismatchError."""
    sync_backed = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    expected_message = r"Method asearch\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        await sync_backed.asearch(collection_name="test_collection", query="test query")
|
||||
|
||||
def test_delete_collection(self, client, mock_qdrant_client):
    """delete_collection checks existence and forwards the delete to the backend."""
    target = "test_collection"
    mock_qdrant_client.collection_exists.return_value = True

    client.delete_collection(collection_name="test_collection")

    mock_qdrant_client.collection_exists.assert_called_once_with(target)
    mock_qdrant_client.delete_collection.assert_called_once_with(collection_name=target)
|
||||
|
||||
def test_delete_collection_not_exists(self, client, mock_qdrant_client):
    """delete_collection raises ValueError for a missing collection and deletes nothing."""
    mock_qdrant_client.collection_exists.return_value = False
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.delete_collection(collection_name="test_collection")

    # The existence check ran, but the backend delete was never reached.
    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
def test_delete_collection_wrong_client_type(self, mock_async_qdrant_client):
    """delete_collection on an async-backed wrapper raises ClientMethodMismatchError."""
    async_backed = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    expected_message = r"Method delete_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        async_backed.delete_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
async def test_adelete_collection(self, async_client, mock_async_qdrant_client):
    """adelete_collection checks existence and forwards the delete to the backend."""
    target = "test_collection"
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.delete_collection = AsyncMock()

    await async_client.adelete_collection(collection_name="test_collection")

    mock_async_qdrant_client.collection_exists.assert_called_once_with(target)
    mock_async_qdrant_client.delete_collection.assert_called_once_with(
        collection_name=target
    )
|
||||
|
||||
@pytest.mark.asyncio
async def test_adelete_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """adelete_collection raises ValueError for a missing collection and deletes nothing."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.adelete_collection(collection_name="test_collection")

    # The existence check ran, but the backend delete was never reached.
    mock_async_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    mock_async_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_adelete_collection_wrong_client_type(self, mock_qdrant_client):
    """adelete_collection on a sync-backed wrapper raises ClientMethodMismatchError."""
    sync_backed = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    expected_message = r"Method adelete_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        await sync_backed.adelete_collection(collection_name="test_collection")
|
||||
|
||||
def test_reset(self, client, mock_qdrant_client):
    """reset lists the collections once and deletes each of them by name."""
    collection_names = ("collection1", "collection2", "collection3")

    listed = []
    for collection_name in collection_names:
        entry = Mock()
        # Assign after construction: Mock(name=...) would name the mock itself.
        entry.name = collection_name
        listed.append(entry)

    listing = Mock()
    listing.collections = listed
    mock_qdrant_client.get_collections.return_value = listing

    client.reset()

    mock_qdrant_client.get_collections.assert_called_once()
    assert mock_qdrant_client.delete_collection.call_count == 3
    for collection_name in collection_names:
        mock_qdrant_client.delete_collection.assert_any_call(
            collection_name=collection_name
        )
|
||||
|
||||
def test_reset_no_collections(self, client, mock_qdrant_client):
    """reset with an empty collection listing performs no deletions."""
    empty_listing = Mock()
    empty_listing.collections = []
    mock_qdrant_client.get_collections.return_value = empty_listing

    client.reset()

    mock_qdrant_client.get_collections.assert_called_once()
    mock_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
def test_reset_wrong_client_type(self, mock_async_qdrant_client):
    """reset on an async-backed wrapper raises ClientMethodMismatchError."""
    async_backed = QdrantClient(
        client=mock_async_qdrant_client, embedding_function=Mock()
    )
    expected_message = r"Method reset\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        async_backed.reset()
|
||||
|
||||
@pytest.mark.asyncio
async def test_areset(self, async_client, mock_async_qdrant_client):
    """areset lists the collections once and deletes each of them by name."""
    collection_names = ("collection1", "collection2", "collection3")

    listed = []
    for collection_name in collection_names:
        entry = Mock()
        # Assign after construction: Mock(name=...) would name the mock itself.
        entry.name = collection_name
        listed.append(entry)

    listing = Mock()
    listing.collections = listed
    mock_async_qdrant_client.get_collections = AsyncMock(return_value=listing)
    mock_async_qdrant_client.delete_collection = AsyncMock()

    await async_client.areset()

    mock_async_qdrant_client.get_collections.assert_called_once()
    assert mock_async_qdrant_client.delete_collection.call_count == 3
    for collection_name in collection_names:
        mock_async_qdrant_client.delete_collection.assert_any_call(
            collection_name=collection_name
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_areset_no_collections(self, async_client, mock_async_qdrant_client):
    """areset with an empty collection listing performs no deletions."""
    empty_listing = Mock()
    empty_listing.collections = []
    mock_async_qdrant_client.get_collections = AsyncMock(return_value=empty_listing)

    await async_client.areset()

    mock_async_qdrant_client.get_collections.assert_called_once()
    mock_async_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_areset_wrong_client_type(self, mock_qdrant_client):
    """areset on a sync-backed wrapper raises ClientMethodMismatchError."""
    sync_backed = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
    expected_message = r"Method areset\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_message):
        await sync_backed.areset()
|
||||
218
tests/rag/test_error_handling.py
Normal file
218
tests/rag/test_error_handling.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tests for RAG client error handling scenarios."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_connection_failure(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.search returns [] when get_rag_client raises ConnectionError."""
    mock_get_client.side_effect = ConnectionError("Unable to connect to ChromaDB")
    knowledge = KnowledgeStorage(collection_name="connection_test")

    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_search_timeout(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.search returns [] when the client's search raises TimeoutError."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.search.side_effect = TimeoutError("Search operation timed out")
    knowledge = KnowledgeStorage(collection_name="timeout_test")

    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_collection_not_found(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.search returns [] when the client reports a missing collection."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.search.side_effect = ValueError(
        "Collection 'knowledge_missing' does not exist"
    )
    knowledge = KnowledgeStorage(collection_name="missing_collection")

    found = knowledge.search(["test query"])

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock) -> None:
    """Constructing KnowledgeStorage propagates the ValueError raised by build_embedder."""
    mock_get_client.return_value = MagicMock()
    embedder_failure = ValueError("Unsupported provider: invalid_provider")

    # build_embedder is replaced by a mock that raises as soon as it is called.
    with patch(
        "crewai.knowledge.storage.knowledge_storage.build_embedder",
        side_effect=embedder_failure,
    ), pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
        KnowledgeStorage(
            embedder={"provider": "invalid_provider"},
            collection_name="invalid_embedding_test",
        )
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
    """RAGStorage.search returns [] when the client's search raises RuntimeError."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.search.side_effect = RuntimeError("ChromaDB server error")
    memory = RAGStorage("short_term", crew=None)

    found = memory.search("test query")

    assert found == []
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
    """RAGStorage.save does not propagate an exception raised by add_documents."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.add_documents.side_effect = Exception("Failed to add documents")
    memory = RAGStorage("long_term", crew=None)

    # There is no explicit assertion: the test passes if save() returns normally.
    memory.save("test memory", {"key": "value"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.reset does not propagate a readonly-database error."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.side_effect = Exception(
        "attempt to write a readonly database"
    )
    knowledge = KnowledgeStorage(collection_name="readonly_test")

    # There is no explicit assertion: the test passes if reset() returns normally.
    knowledge.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_reset_collection_does_not_exist(
    mock_get_client: MagicMock,
) -> None:
    """KnowledgeStorage.reset does not propagate a collection-does-not-exist error."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.side_effect = Exception("Collection does not exist")
    knowledge = KnowledgeStorage(collection_name="nonexistent_test")

    # There is no explicit assertion: the test passes if reset() returns normally.
    knowledge.reset()
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
    """RAGStorage.reset raises with a reset-specific message on an unexpected delete error."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    rag_client.delete_collection.side_effect = Exception("Unexpected database error")
    memory = RAGStorage("entities", crew=None)
    expected_message = "An error occurred while resetting the entities memory"

    with pytest.raises(Exception, match=expected_message):
        memory.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
    """KnowledgeStorage.search returns a four-item list when the client yields malformed hits."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client
    # One well-formed hit followed by three malformed ones.
    rag_client.search.return_value = [
        {"content": "valid result", "metadata": {"source": "test"}},
        {"invalid": "missing content field", "metadata": {"source": "test"}},
        None,
        {"content": None, "metadata": {"source": "test"}},
    ]
    knowledge = KnowledgeStorage(collection_name="malformed_test")

    found = knowledge.search(["test query"])

    assert isinstance(found, list)
    assert len(found) == 4
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> None:
    """Test KnowledgeStorage handles network interruptions during operations.

    The first search runs against a client that raises ``ConnectionError`` and
    is expected to yield an empty list. The client is then switched back to
    returning data, and a second search is expected to return that data.
    """
    mock_client = MagicMock()
    mock_get_client.return_value = mock_client

    # Single definition of the payload served once the "network" is back.
    recovered_results = [
        {"content": "recovered result", "score": 0.8, "metadata": {"source": "test"}}
    ]

    # Simulate the outage with one exception; recovery is configured
    # explicitly below rather than through an unused side_effect sequence.
    mock_client.search.side_effect = ConnectionError("Network interruption")

    storage = KnowledgeStorage(collection_name="network_test")

    first_attempt = storage.search(["test query"])
    assert first_attempt == []

    # Clearing side_effect makes the mock fall back to return_value.
    mock_client.search.side_effect = None
    mock_client.search.return_value = recovered_results

    second_attempt = storage.search(["test query"])
    assert len(second_attempt) == 1
    assert second_attempt[0]["content"] == "recovered result"
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
    """Ensure RAGStorage.save() does not raise when collection creation fails.

    The mocked client's ``get_or_create_collection`` raises; the test passes
    as long as nothing escapes ``save()``.
    """
    creation_error = Exception("Failed to create collection")
    rag_client = MagicMock()
    rag_client.get_or_create_collection.side_effect = creation_error
    mock_get_client.return_value = rag_client

    user_memory = RAGStorage("user_memory", crew=None)

    # No explicit assertion: an exception leaving save() would fail the test.
    user_memory.save("test data", {"metadata": "test"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
    mock_get_client: MagicMock,
) -> None:
    """Check the ValueError raised by save() on an embedding dimension mismatch.

    Collection creation succeeds on the mocked client, but ``add_documents``
    raises a dimension-mismatch exception. ``save()`` is expected to raise a
    ``ValueError`` whose message contains the mismatch notice, the
    embedding-model hint, and the reset command.
    """
    mismatch_error = Exception("Embedding dimension mismatch: expected 384, got 1536")
    rag_client = MagicMock()
    rag_client.get_or_create_collection.return_value = None
    rag_client.add_documents.side_effect = mismatch_error
    mock_get_client.return_value = rag_client

    knowledge = KnowledgeStorage(collection_name="dimension_detailed_test")

    with pytest.raises(ValueError) as exc_info:
        knowledge.save(["test document"])

    expected_fragments = (
        "Embedding dimension mismatch",
        "Make sure you're using the same embedding model",
        "crewai reset-memories -a",
    )
    for fragment in expected_fragments:
        assert fragment in str(exc_info.value)
|
||||
Reference in New Issue
Block a user