feat: add configurable search parameters for RAG, knowledge, and memory (#3531)

- Add limit and score_threshold to BaseRagConfig, propagate to clients  
- Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6)  
- Fix linting (ruff, mypy, PERF203) and refactor save logic  
- Update tests for new defaults and ChromaDB behavior
This commit is contained in:
Greyson LaLonde
2025-09-18 16:58:03 -04:00
committed by GitHub
parent 578fa8c2e4
commit d4aa676195
18 changed files with 173 additions and 118 deletions

View File

@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
def __init__(
self,
client: ChromaDBClientType,
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
Args:
client: Pre-configured ChromaDB client instance.
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
if not documents:
raise ValueError("Documents list cannot be empty")
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(collection_name),
embedding_function=self.embedding_function,
)
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
"Use asearch() for AsyncClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = self.client.get_collection(
collection = self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
"Use search() for ClientAPI."
)
if "limit" not in kwargs:
kwargs["limit"] = self.default_limit
if "score_threshold" not in kwargs:
kwargs["score_threshold"] = self.default_score_threshold
params = _extract_search_params(kwargs)
collection = await self.client.get_collection(
collection = await self.client.get_or_create_collection(
name=_sanitize_collection_name(params.collection_name),
embedding_function=self.embedding_function,
)