feat: add configurable search parameters for RAG, knowledge, and memory (#3531)

- Add limit and score_threshold to BaseRagConfig and propagate them to the ChromaDB and Qdrant clients as default_limit / default_score_threshold  
- Update default search params in RAG storage, knowledge, and memory (limit: 3 → 5, score_threshold: 0.35 → 0.6)  
- Fix lint findings (ruff, including PERF203; mypy) and refactor the EntityMemory batch-save logic into a per-item helper  
- Update tests for the new defaults (client search limit 10 → 5) and for ChromaDB now using get_or_create_collection instead of get_collection
This commit is contained in:
Greyson LaLonde
2025-09-18 16:58:03 -04:00
committed by GitHub
parent 578fa8c2e4
commit d4aa676195
18 changed files with 173 additions and 118 deletions

View File

@@ -43,7 +43,7 @@ class Knowledge(BaseModel):
self.sources = sources
def query(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""
Query across all knowledge sources to find the most relevant information.

View File

@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")
results_limit: int = Field(default=5, description="The number of results to return")
score_threshold: float = Field(
default=0.35,
default=0.6,
description="The minimum score for a result to be considered relevant",
)

View File

@@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC):
def search(
self,
query: list[str],
limit: int = 3,
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""

View File

@@ -1,4 +1,5 @@
import logging
import traceback
import warnings
from typing import Any, cast
@@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def search(
self,
query: list[str],
limit: int = 3,
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[SearchResult]:
try:
if not query:
@@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during knowledge search: {e!s}")
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
@@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
logging.error(f"Error during knowledge reset: {e!s}")
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)
def save(self, documents: list[str]) -> None:
try:

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -90,8 +90,8 @@ class EntityMemory(Memory):
saved_count = 0
errors = []
try:
for item in items:
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item and return success status."""
try:
if self._memory_provider == "mem0":
data = f"""
@@ -103,10 +103,18 @@ class EntityMemory(Memory):
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
saved_count += 1
super(EntityMemory, self).save(data, item.metadata)
return True, None
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
raise Exception(
f"An error occurred while resetting the entity memory: {e}"
) from e

View File

@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any):
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Saves a value into the external storage."""
crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
@@ -12,8 +12,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
_agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
@@ -54,9 +54,9 @@ class Memory(BaseModel):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
except ImportError as e:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
) from e
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None:
crewai_event_bus.emit(
self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
limit: int = 5,
score_threshold: float = 0.6,
):
crewai_event_bus.emit(
self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)
) from e

View File

@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params)
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]:
params = {
"query": query,

View File

@@ -1,4 +1,5 @@
import logging
import traceback
import warnings
from typing import Any
@@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage):
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(f"Error during {self.type} save: {e!s}")
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
limit: int = 3,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[Any]:
try:
client = self._get_client()
@@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage):
score_threshold=score_threshold,
)
except Exception as e:
logging.error(f"Error during {self.type} search: {e!s}")
logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:

View File

@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)

View File

@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
return ChromaDBClient(
client=client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -14,3 +14,5 @@ class BaseRagConfig:
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)

View File

@@ -6,8 +6,8 @@ from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params,
_process_search_results,
)
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize QdrantClient with client and embedding function.
Args:
client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")

View File

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
client=qdrant_client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
def search(
self,
query: str,
limit: int = 3,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for entries in the storage."""

View File

@@ -236,7 +236,7 @@ class TestChromaDBClient:
def test_add_documents(self, client, mock_chromadb_client) -> None:
"""Test that add_documents adds documents to collection."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -247,7 +247,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents)
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
@@ -262,7 +262,7 @@ class TestChromaDBClient:
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
"""Test add_documents with custom document IDs."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{
@@ -288,7 +288,7 @@ class TestChromaDBClient:
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
"""Test add_documents with documents that have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document without metadata"},
@@ -308,7 +308,7 @@ class TestChromaDBClient:
) -> None:
"""Test add_documents when all documents have no metadata."""
mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"content": "Document 1"},
@@ -335,7 +335,7 @@ class TestChromaDBClient:
) -> None:
"""Test that aadd_documents adds documents to collection asynchronously."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -350,7 +350,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
@@ -368,7 +368,7 @@ class TestChromaDBClient:
) -> None:
"""Test aadd_documents with custom document IDs."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -401,7 +401,7 @@ class TestChromaDBClient:
) -> None:
"""Test aadd_documents with documents that have no metadata."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
@@ -434,7 +434,7 @@ class TestChromaDBClient:
"""Test that search queries the collection correctly."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2"]],
"documents": [["Document 1", "Document 2"]],
@@ -444,13 +444,13 @@ class TestChromaDBClient:
results = client.search(collection_name="test_collection", query="test query")
mock_chromadb_client.get_collection.assert_called_once_with(
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -466,7 +466,7 @@ class TestChromaDBClient:
"""Test search with optional parameters."""
mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = {
"ids": [["doc1", "doc2", "doc3"]],
"documents": [["Document 1", "Document 2", "Document 3"]],
@@ -499,7 +499,7 @@ class TestChromaDBClient:
"""Test that asearch queries the collection correctly."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(
@@ -515,13 +515,13 @@ class TestChromaDBClient:
collection_name="test_collection", query="test query"
)
mock_async_chromadb_client.get_collection.assert_called_once_with(
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection",
embedding_function=async_client.embedding_function,
)
mock_collection.query.assert_called_once_with(
query_texts=["test query"],
n_results=10,
n_results=5,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
@@ -540,7 +540,7 @@ class TestChromaDBClient:
"""Test asearch with optional parameters."""
mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock(
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
mock_collection.query = AsyncMock(

View File

@@ -3,7 +3,8 @@
from unittest.mock import AsyncMock, Mock
import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client import AsyncQdrantClient
from qdrant_client import QdrantClient as SyncQdrantClient
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient
@@ -435,7 +436,7 @@ class TestQdrantClient:
call_args = mock_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False
@@ -540,7 +541,7 @@ class TestQdrantClient:
call_args = mock_async_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10
assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False