feat: add configurable search parameters for RAG, knowledge, and memory (#3531)

- Add `limit` and `score_threshold` fields to BaseRagConfig and propagate them to the ChromaDB and Qdrant clients as `default_limit` / `default_score_threshold`  
- Update default search params in RAG storage, knowledge, and memory (limit 3 → 5, score_threshold 0.35 → 0.6)  
- Fix linting (ruff, mypy, PERF203) and refactor the EntityMemory save logic into a per-item helper so the try/except no longer sits inside the loop body  
- Update tests for the new defaults and for ChromaDB's switch from `get_collection` to `get_or_create_collection`
This commit is contained in:
Greyson LaLonde
2025-09-18 16:58:03 -04:00
committed by GitHub
parent 578fa8c2e4
commit d4aa676195
18 changed files with 173 additions and 118 deletions

View File

@@ -43,7 +43,7 @@ class Knowledge(BaseModel):
self.sources = sources self.sources = sources
def query( def query(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
Query across all knowledge sources to find the most relevant information. Query across all knowledge sources to find the most relevant information.

View File

@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
score_threshold (float): The minimum score for a document to be considered relevant. score_threshold (float): The minimum score for a document to be considered relevant.
""" """
results_limit: int = Field(default=3, description="The number of results to return") results_limit: int = Field(default=5, description="The number of results to return")
score_threshold: float = Field( score_threshold: float = Field(
default=0.35, default=0.6,
description="The minimum score for a result to be considered relevant", description="The minimum score for a result to be considered relevant",
) )

View File

@@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC):
def search( def search(
self, self,
query: list[str], query: list[str],
limit: int = 3, limit: int = 5,
metadata_filter: dict[str, Any] | None = None, metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.6,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search for documents in the knowledge base.""" """Search for documents in the knowledge base."""

View File

@@ -1,4 +1,5 @@
import logging import logging
import traceback
import warnings import warnings
from typing import Any, cast from typing import Any, cast
@@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def search( def search(
self, self,
query: list[str], query: list[str],
limit: int = 3, limit: int = 5,
metadata_filter: dict[str, Any] | None = None, metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.6,
) -> list[SearchResult]: ) -> list[SearchResult]:
try: try:
if not query: if not query:
@@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
score_threshold=score_threshold, score_threshold=score_threshold,
) )
except Exception as e: except Exception as e:
logging.error(f"Error during knowledge search: {e!s}") logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return [] return []
def reset(self) -> None: def reset(self) -> None:
@@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) )
client.delete_collection(collection_name=collection_name) client.delete_collection(collection_name=collection_name)
except Exception as e: except Exception as e:
logging.error(f"Error during knowledge reset: {e!s}") logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)
def save(self, documents: list[str]) -> None: def save(self, documents: list[str]) -> None:
try: try:

View File

@@ -1,20 +1,20 @@
from typing import Any
import time import time
from typing import Any
from pydantic import PrivateAttr from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory): class EntityMemory(Memory):
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
if memory_provider == "mem0": if memory_provider == "mem0":
try: try:
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError: except ImportError as e:
raise ImportError( raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`." "Mem0 is not installed. Please install it with `pip install mem0ai`."
) ) from e
config = embedder_config.get("config") if embedder_config else None config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config) storage = Mem0Storage(type="short_term", crew=crew, config=config)
else: else:
@@ -90,8 +90,8 @@ class EntityMemory(Memory):
saved_count = 0 saved_count = 0
errors = [] errors = []
try: def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
for item in items: """Save a single item and return success status."""
try: try:
if self._memory_provider == "mem0": if self._memory_provider == "mem0":
data = f""" data = f"""
@@ -103,10 +103,18 @@ class EntityMemory(Memory):
else: else:
data = f"{item.name}({item.type}): {item.description}" data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata) super(EntityMemory, self).save(data, item.metadata)
saved_count += 1 return True, None
except Exception as e: except Exception as e:
errors.append(f"{item.name}: {str(e)}") return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch: if is_batch:
emit_value = f"Saved {saved_count} entities" emit_value = f"Saved {saved_count} entities"
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
score_threshold: float = 0.35, score_threshold: float = 0.6,
): ):
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
try: try:
self.storage.reset() self.storage.reset()
except Exception as e: except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}") raise Exception(
f"An error occurred while resetting the entity memory: {e}"
) from e

View File

@@ -1,41 +1,41 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
import time import time
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
class ExternalMemory(Memory): class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any): def __init__(self, storage: Storage | None = None, **data: Any):
super().__init__(storage=storage, **data) super().__init__(storage=storage, **data)
@staticmethod @staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage": def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config) return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod @staticmethod
def external_supported_storages() -> Dict[str, Any]: def external_supported_storages() -> dict[str, Any]:
return { return {
"mem0": ExternalMemory._configure_mem0, "mem0": ExternalMemory._configure_mem0,
} }
@staticmethod @staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage: def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config: if not embedder_config:
raise ValueError("embedder_config is required") raise ValueError("embedder_config is required")
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
def save( def save(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Saves a value into the external storage.""" """Saves a value into the external storage."""
crewai_event_bus.emit( crewai_event_bus.emit(
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
score_threshold: float = 0.35, score_threshold: float = 0.6,
): ):
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -12,8 +12,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata. Base class for memory, now supporting agent tags and generic metadata.
""" """
embedder_config: Optional[Dict[str, Any]] = None embedder_config: dict[str, Any] | None = None
crew: Optional[Any] = None crew: Any | None = None
storage: Any storage: Any
_agent: Optional["Agent"] = None _agent: Optional["Agent"] = None
@@ -45,7 +45,7 @@ class Memory(BaseModel):
def save( def save(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
) -> None: ) -> None:
metadata = metadata or {} metadata = metadata or {}
@@ -54,9 +54,9 @@ class Memory(BaseModel):
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
score_threshold: float = 0.35, score_threshold: float = 0.6,
) -> List[Any]: ) -> list[Any]:
return self.storage.search( return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold query=query, limit=limit, score_threshold=score_threshold
) )

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time import time
from typing import Any
from pydantic import PrivateAttr from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory): class ShortTermMemory(Memory):
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
MemoryItem instances. MemoryItem instances.
""" """
_memory_provider: Optional[str] = PrivateAttr() _memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0": if memory_provider == "mem0":
try: try:
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError: except ImportError as e:
raise ImportError( raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`." "Mem0 is not installed. Please install it with `pip install mem0ai`."
) ) from e
config = embedder_config.get("config") if embedder_config else None config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config) storage = Mem0Storage(type="short_term", crew=crew, config=config)
else: else:
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
def save( def save(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
) -> None: ) -> None:
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
score_threshold: float = 0.35, score_threshold: float = 0.6,
): ):
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"An error occurred while resetting the short-term memory: {e}" f"An error occurred while resetting the short-term memory: {e}"
) ) from e

View File

@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
self.memory.add(conversations, **params) self.memory.add(conversations, **params)
def search( def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35 self, query: str, limit: int = 5, score_threshold: float = 0.6
) -> list[Any]: ) -> list[Any]:
params = { params = {
"query": query, "query": query,

View File

@@ -1,4 +1,5 @@
import logging import logging
import traceback
import warnings import warnings
from typing import Any from typing import Any
@@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage):
client.add_documents(collection_name=collection_name, documents=[document]) client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e: except Exception as e:
logging.error(f"Error during {self.type} save: {e!s}") logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
filter: dict[str, Any] | None = None, filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.6,
) -> list[Any]: ) -> list[Any]:
try: try:
client = self._get_client() client = self._get_client()
@@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage):
score_threshold=score_threshold, score_threshold=score_threshold,
) )
except Exception as e: except Exception as e:
logging.error(f"Error during {self.type} search: {e!s}") logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return [] return []
def reset(self) -> None: def reset(self) -> None:

View File

@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes: Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI). client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents. embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
""" """
def __init__( def __init__(
self, self,
client: ChromaDBClientType, client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction, embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None: ) -> None:
"""Initialize ChromaDBClient with client and embedding function. """Initialize ChromaDBClient with client and embedding function.
Args: Args:
client: Pre-configured ChromaDB client instance. client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion. embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
""" """
self.client = client self.client = client
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection( def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams] self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection( collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name), name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection( collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name), name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI." "Use asearch() for AsyncClientAPI."
) )
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs) params = _extract_search_params(kwargs)
collection = self.client.get_collection( collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name), name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI." "Use search() for ClientAPI."
) )
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs) params = _extract_search_params(kwargs)
collection = await self.client.get_collection( collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name), name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )

View File

@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
return ChromaDBClient( return ChromaDBClient(
client=client, client=client,
embedding_function=config.embedding_function, embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
) )

View File

@@ -14,3 +14,5 @@ class BaseRagConfig:
provider: SupportedProvider = field(init=False) provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None) embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)

View File

@@ -6,8 +6,8 @@ from typing_extensions import Unpack
from crewai.rag.core.base_client import ( from crewai.rag.core.base_client import (
BaseClient, BaseClient,
BaseCollectionParams,
BaseCollectionAddParams, BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams, BaseCollectionSearchParams,
) )
from crewai.rag.core.exceptions import ClientMethodMismatchError from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams, QdrantCollectionCreateParams,
) )
from crewai.rag.qdrant.utils import ( from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client, _is_async_client,
_is_async_embedding_function, _is_async_embedding_function,
_is_sync_client, _is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params, _prepare_search_params,
_process_search_results, _process_search_results,
) )
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
Attributes: Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient). client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents. embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
""" """
def __init__( def __init__(
self, self,
client: QdrantClientType, client: QdrantClientType,
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction, embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None: ) -> None:
"""Initialize QdrantClient with client and embedding function. """Initialize QdrantClient with client and embedding function.
Args: Args:
client: Pre-configured Qdrant client instance. client: Pre-configured Qdrant client instance.
embedding_function: Embedding function for text to vector conversion. embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
""" """
self.client = client self.client = client
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None: def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant. """Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
query = kwargs["query"] query = kwargs["query"]
limit = kwargs.get("limit", 10) limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter") metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold") score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not self.client.collection_exists(collection_name): if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist") raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
query = kwargs["query"] query = kwargs["query"]
limit = kwargs.get("limit", 10) limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter") metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold") score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not await self.client.collection_exists(collection_name): if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist") raise ValueError(f"Collection '{collection_name}' does not exist")

View File

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration.""" """Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig from crewai.rag.qdrant.config import QdrantConfig
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
qdrant_client = SyncQdrantClientBase(**config.options) qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient( return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function client=qdrant_client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
) )

View File

@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
def search( def search(
self, self,
query: str, query: str,
limit: int = 3, limit: int = 5,
filter: dict[str, Any] | None = None, filter: dict[str, Any] | None = None,
score_threshold: float = 0.35, score_threshold: float = 0.6,
) -> list[Any]: ) -> list[Any]:
"""Search for entries in the storage.""" """Search for entries in the storage."""

View File

@@ -236,7 +236,7 @@ class TestChromaDBClient:
def test_add_documents(self, client, mock_chromadb_client) -> None: def test_add_documents(self, client, mock_chromadb_client) -> None:
"""Test that add_documents adds documents to collection.""" """Test that add_documents adds documents to collection."""
mock_collection = Mock() mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [ documents: list[BaseRecord] = [
{ {
@@ -247,7 +247,7 @@ class TestChromaDBClient:
client.add_documents(collection_name="test_collection", documents=documents) client.add_documents(collection_name="test_collection", documents=documents)
mock_chromadb_client.get_collection.assert_called_once_with( mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection", name="test_collection",
embedding_function=client.embedding_function, embedding_function=client.embedding_function,
) )
@@ -262,7 +262,7 @@ class TestChromaDBClient:
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None: def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
"""Test add_documents with custom document IDs.""" """Test add_documents with custom document IDs."""
mock_collection = Mock() mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [ documents: list[BaseRecord] = [
{ {
@@ -288,7 +288,7 @@ class TestChromaDBClient:
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None: def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
"""Test add_documents with documents that have no metadata.""" """Test add_documents with documents that have no metadata."""
mock_collection = Mock() mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [ documents: list[BaseRecord] = [
{"content": "Document without metadata"}, {"content": "Document without metadata"},
@@ -308,7 +308,7 @@ class TestChromaDBClient:
) -> None: ) -> None:
"""Test add_documents when all documents have no metadata.""" """Test add_documents when all documents have no metadata."""
mock_collection = Mock() mock_collection = Mock()
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [ documents: list[BaseRecord] = [
{"content": "Document 1"}, {"content": "Document 1"},
@@ -335,7 +335,7 @@ class TestChromaDBClient:
) -> None: ) -> None:
"""Test that aadd_documents adds documents to collection asynchronously.""" """Test that aadd_documents adds documents to collection asynchronously."""
mock_collection = AsyncMock() mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock( mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection return_value=mock_collection
) )
@@ -350,7 +350,7 @@ class TestChromaDBClient:
collection_name="test_collection", documents=documents collection_name="test_collection", documents=documents
) )
mock_async_chromadb_client.get_collection.assert_called_once_with( mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection", name="test_collection",
embedding_function=async_client.embedding_function, embedding_function=async_client.embedding_function,
) )
@@ -368,7 +368,7 @@ class TestChromaDBClient:
) -> None: ) -> None:
"""Test aadd_documents with custom document IDs.""" """Test aadd_documents with custom document IDs."""
mock_collection = AsyncMock() mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock( mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection return_value=mock_collection
) )
@@ -401,7 +401,7 @@ class TestChromaDBClient:
) -> None: ) -> None:
"""Test aadd_documents with documents that have no metadata.""" """Test aadd_documents with documents that have no metadata."""
mock_collection = AsyncMock() mock_collection = AsyncMock()
mock_async_chromadb_client.get_collection = AsyncMock( mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection return_value=mock_collection
) )
@@ -434,7 +434,7 @@ class TestChromaDBClient:
"""Test that search queries the collection correctly.""" """Test that search queries the collection correctly."""
mock_collection = Mock() mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"} mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = { mock_collection.query.return_value = {
"ids": [["doc1", "doc2"]], "ids": [["doc1", "doc2"]],
"documents": [["Document 1", "Document 2"]], "documents": [["Document 1", "Document 2"]],
@@ -444,13 +444,13 @@ class TestChromaDBClient:
results = client.search(collection_name="test_collection", query="test query") results = client.search(collection_name="test_collection", query="test query")
mock_chromadb_client.get_collection.assert_called_once_with( mock_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection", name="test_collection",
embedding_function=client.embedding_function, embedding_function=client.embedding_function,
) )
mock_collection.query.assert_called_once_with( mock_collection.query.assert_called_once_with(
query_texts=["test query"], query_texts=["test query"],
n_results=10, n_results=5,
where=None, where=None,
where_document=None, where_document=None,
include=["metadatas", "documents", "distances"], include=["metadatas", "documents", "distances"],
@@ -466,7 +466,7 @@ class TestChromaDBClient:
"""Test search with optional parameters.""" """Test search with optional parameters."""
mock_collection = Mock() mock_collection = Mock()
mock_collection.metadata = {"hnsw:space": "cosine"} mock_collection.metadata = {"hnsw:space": "cosine"}
mock_chromadb_client.get_collection.return_value = mock_collection mock_chromadb_client.get_or_create_collection.return_value = mock_collection
mock_collection.query.return_value = { mock_collection.query.return_value = {
"ids": [["doc1", "doc2", "doc3"]], "ids": [["doc1", "doc2", "doc3"]],
"documents": [["Document 1", "Document 2", "Document 3"]], "documents": [["Document 1", "Document 2", "Document 3"]],
@@ -499,7 +499,7 @@ class TestChromaDBClient:
"""Test that asearch queries the collection correctly.""" """Test that asearch queries the collection correctly."""
mock_collection = AsyncMock() mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"} mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock( mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection return_value=mock_collection
) )
mock_collection.query = AsyncMock( mock_collection.query = AsyncMock(
@@ -515,13 +515,13 @@ class TestChromaDBClient:
collection_name="test_collection", query="test query" collection_name="test_collection", query="test query"
) )
mock_async_chromadb_client.get_collection.assert_called_once_with( mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
name="test_collection", name="test_collection",
embedding_function=async_client.embedding_function, embedding_function=async_client.embedding_function,
) )
mock_collection.query.assert_called_once_with( mock_collection.query.assert_called_once_with(
query_texts=["test query"], query_texts=["test query"],
n_results=10, n_results=5,
where=None, where=None,
where_document=None, where_document=None,
include=["metadatas", "documents", "distances"], include=["metadatas", "documents", "distances"],
@@ -540,7 +540,7 @@ class TestChromaDBClient:
"""Test asearch with optional parameters.""" """Test asearch with optional parameters."""
mock_collection = AsyncMock() mock_collection = AsyncMock()
mock_collection.metadata = {"hnsw:space": "cosine"} mock_collection.metadata = {"hnsw:space": "cosine"}
mock_async_chromadb_client.get_collection = AsyncMock( mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection return_value=mock_collection
) )
mock_collection.query = AsyncMock( mock_collection.query = AsyncMock(

View File

@@ -3,7 +3,8 @@
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
import pytest import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient from qdrant_client import AsyncQdrantClient
from qdrant_client import QdrantClient as SyncQdrantClient
from crewai.rag.core.exceptions import ClientMethodMismatchError from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient from crewai.rag.qdrant.client import QdrantClient
@@ -435,7 +436,7 @@ class TestQdrantClient:
call_args = mock_qdrant_client.query_points.call_args call_args = mock_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection" assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3] assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10 assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False assert call_args.kwargs["with_vectors"] is False
@@ -540,7 +541,7 @@ class TestQdrantClient:
call_args = mock_async_qdrant_client.query_points.call_args call_args = mock_async_qdrant_client.query_points.call_args
assert call_args.kwargs["collection_name"] == "test_collection" assert call_args.kwargs["collection_name"] == "test_collection"
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3] assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
assert call_args.kwargs["limit"] == 10 assert call_args.kwargs["limit"] == 5
assert call_args.kwargs["with_payload"] is True assert call_args.kwargs["with_payload"] is True
assert call_args.kwargs["with_vectors"] is False assert call_args.kwargs["with_vectors"] is False