feat: add configurable search parameters for RAG, knowledge, and memory (#3531)

- Add limit and score_threshold to BaseRagConfig, propagate to clients  
- Update default search params in RAG storage, knowledge, and memory (limit=5, score_threshold=0.6)  
- Fix linting (ruff, mypy, PERF203) and refactor save logic  
- Update tests for new defaults and ChromaDB behavior
This commit is contained in:
Greyson LaLonde
2025-09-18 16:58:03 -04:00
committed by GitHub
parent 578fa8c2e4
commit d4aa676195
18 changed files with 173 additions and 118 deletions

View File

@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
    self,
    client: ChromaDBClientType,
    embedding_function: ChromaEmbeddingFunction,
    default_limit: int = 5,
    default_score_threshold: float = 0.6,
) -> None:
    """Store the ChromaDB client, embedding function, and search defaults.

    Args:
        client: Already-configured ChromaDB client instance.
        embedding_function: Embedding function passed to ChromaDB whenever
            a collection is looked up.
        default_limit: Result count the search methods apply when the
            caller does not pass ``limit``.
        default_score_threshold: Minimum score the search methods apply
            when the caller does not pass ``score_threshold``.
    """
    # Collaborators first, then the per-client search fallbacks.
    self.client, self.embedding_function = client, embedding_function
    self.default_limit, self.default_score_threshold = (
        default_limit,
        default_score_threshold,
    )
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)

View File

@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
return ChromaDBClient(
client=client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -14,3 +14,5 @@ class BaseRagConfig:
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)

View File

@@ -6,8 +6,8 @@ from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_create_point_from_document,
_get_collection_params,
_prepare_search_params,
_process_search_results,
)
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
Attributes:
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
    self,
    client: QdrantClientType,
    embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
    default_limit: int = 5,
    default_score_threshold: float = 0.6,
) -> None:
    """Store the Qdrant client, embedding function, and search defaults.

    Args:
        client: Already-configured Qdrant client instance.
        embedding_function: Sync or async embedding function used for
            text to vector conversion.
        default_limit: Result count the search methods fall back to when
            ``limit`` is absent from the search kwargs.
        default_score_threshold: Minimum score the search methods fall
            back to when ``score_threshold`` is absent from the search
            kwargs.
    """
    # Collaborators first, then the per-client search fallbacks.
    self.client, self.embedding_function = client, embedding_function
    self.default_limit, self.default_score_threshold = (
        default_limit,
        default_score_threshold,
    )
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
query = kwargs["query"]
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", self.default_limit)
metadata_filter = kwargs.get("metadata_filter")
score_threshold = kwargs.get("score_threshold")
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")

View File

@@ -1,6 +1,7 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
client=qdrant_client,
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
)

View File

@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
def search(
self,
query: str,
limit: int = 3,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for entries in the storage."""