mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients - Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6) - Fix linting (ruff, mypy, PERF203) and refactor save logic - Update tests for new defaults and ChromaDB behavior
This commit is contained in:
@@ -6,8 +6,8 @@ from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
Reference in New Issue
Block a user