mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients - Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6) - Fix linting (ruff, mypy, PERF203) and refactor save logic - Update tests for new defaults and ChromaDB behavior
This commit is contained in:
@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
return ChromaDBClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
)
|
||||
|
||||
@@ -14,3 +14,5 @@ class BaseRagConfig:
|
||||
|
||||
provider: SupportedProvider = field(init=False)
|
||||
embedding_function: Any | None = field(default=None)
|
||||
limit: int = field(default=5)
|
||||
score_threshold: float = field(default=0.6)
|
||||
|
||||
@@ -6,8 +6,8 @@ from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Factory functions for creating Qdrant clients from configuration."""
|
||||
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
|
||||
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||
return QdrantClient(
|
||||
client=qdrant_client, embedding_function=config.embedding_function
|
||||
client=qdrant_client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
)
|
||||
|
||||
@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user