mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients - Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6) - Fix linting (ruff, mypy, PERF203) and refactor save logic - Update tests for new defaults and ChromaDB behavior
This commit is contained in:
@@ -236,7 +236,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents(self, client, mock_chromadb_client) -> None:
|
||||
"""Test that add_documents adds documents to collection."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
@@ -247,7 +247,7 @@ class TestChromaDBClient:
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
@@ -262,7 +262,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with custom document IDs."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
@@ -288,7 +288,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with documents that have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document without metadata"},
|
||||
@@ -308,7 +308,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test add_documents when all documents have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document 1"},
|
||||
@@ -335,7 +335,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -350,7 +350,7 @@ class TestChromaDBClient:
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
@@ -368,7 +368,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test aadd_documents with custom document IDs."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test aadd_documents with documents that have no metadata."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -434,7 +434,7 @@ class TestChromaDBClient:
|
||||
"""Test that search queries the collection correctly."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2"]],
|
||||
"documents": [["Document 1", "Document 2"]],
|
||||
@@ -444,13 +444,13 @@ class TestChromaDBClient:
|
||||
|
||||
results = client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
n_results=5,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
@@ -466,7 +466,7 @@ class TestChromaDBClient:
|
||||
"""Test search with optional parameters."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2", "doc3"]],
|
||||
"documents": [["Document 1", "Document 2", "Document 3"]],
|
||||
@@ -499,7 +499,7 @@ class TestChromaDBClient:
|
||||
"""Test that asearch queries the collection correctly."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
@@ -515,13 +515,13 @@ class TestChromaDBClient:
|
||||
collection_name="test_collection", query="test query"
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
n_results=5,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
@@ -540,7 +540,7 @@ class TestChromaDBClient:
|
||||
"""Test asearch with optional parameters."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client import QdrantClient as SyncQdrantClient
|
||||
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
@@ -435,7 +436,7 @@ class TestQdrantClient:
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
@@ -540,7 +541,7 @@ class TestQdrantClient:
|
||||
call_args = mock_async_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user