mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients
- Update default search params in RAG storage, knowledge, and memory (limit=5, score_threshold=0.6)
- Fix linting (ruff, mypy, PERF203) and refactor save logic
- Update tests for new defaults and ChromaDB behavior
This commit is contained in:
@@ -43,7 +43,7 @@ class Knowledge(BaseModel):
|
||||
self.sources = sources
|
||||
|
||||
def query(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
|
||||
@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
|
||||
score_threshold (float): The minimum score for a document to be considered relevant.
|
||||
"""
|
||||
|
||||
results_limit: int = Field(default=3, description="The number of results to return")
|
||||
results_limit: int = Field(default=5, description="The number of results to return")
|
||||
score_threshold: float = Field(
|
||||
default=0.35,
|
||||
default=0.6,
|
||||
description="The minimum score for a result to be considered relevant",
|
||||
)
|
||||
|
||||
@@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC):
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
try:
|
||||
if not query:
|
||||
@@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge search: {e!s}")
|
||||
logging.error(
|
||||
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during knowledge reset: {e!s}")
|
||||
logging.error(
|
||||
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def save(self, documents: list[str]) -> None:
|
||||
try:
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
@@ -90,23 +90,31 @@ class EntityMemory(Memory):
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item and return success status."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super(EntityMemory, self).save(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
success, error = save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the entity memory: {e}"
|
||||
) from e
|
||||
|
||||
34
src/crewai/memory/external/external_memory.py
vendored
34
src/crewai/memory/external/external_memory.py
vendored
@@ -1,41 +1,41 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
|
||||
class ExternalMemory(Memory):
|
||||
def __init__(self, storage: Optional[Storage] = None, **data: Any):
|
||||
def __init__(self, storage: Storage | None = None, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@staticmethod
|
||||
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> Dict[str, Any]:
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
return {
|
||||
"mem0": ExternalMemory._configure_mem0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
|
||||
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -12,8 +12,8 @@ class Memory(BaseModel):
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
crew: Optional[Any] = None
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
_agent: Optional["Agent"] = None
|
||||
@@ -45,7 +45,7 @@ class Memory(BaseModel):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
|
||||
@@ -54,9 +54,9 @@ class Memory(BaseModel):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the short-term memory: {e}"
|
||||
)
|
||||
) from e
|
||||
|
||||
@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
|
||||
self.memory.add(conversations, **params)
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
self, query: str, limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
@@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
client.add_documents(collection_name=collection_name, documents=[document])
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {e!s}")
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
try:
|
||||
client = self._get_client()
|
||||
@@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {e!s}")
|
||||
logging.error(
|
||||
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
return ChromaDBClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
)
|
||||
|
||||
@@ -14,3 +14,5 @@ class BaseRagConfig:
|
||||
|
||||
provider: SupportedProvider = field(init=False)
|
||||
embedding_function: Any | None = field(default=None)
|
||||
limit: int = field(default=5)
|
||||
score_threshold: float = field(default=0.6)
|
||||
|
||||
@@ -6,8 +6,8 @@ from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Factory functions for creating Qdrant clients from configuration."""
|
||||
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
|
||||
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||
return QdrantClient(
|
||||
client=qdrant_client, embedding_function=config.embedding_function
|
||||
client=qdrant_client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
)
|
||||
|
||||
@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
|
||||
|
||||
@@ -236,7 +236,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents(self, client, mock_chromadb_client) -> None:
|
||||
"""Test that add_documents adds documents to collection."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
@@ -247,7 +247,7 @@ class TestChromaDBClient:
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
@@ -262,7 +262,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with custom document IDs."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
@@ -288,7 +288,7 @@ class TestChromaDBClient:
|
||||
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
|
||||
"""Test add_documents with documents that have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document without metadata"},
|
||||
@@ -308,7 +308,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test add_documents when all documents have no metadata."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document 1"},
|
||||
@@ -335,7 +335,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -350,7 +350,7 @@ class TestChromaDBClient:
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
@@ -368,7 +368,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test aadd_documents with custom document IDs."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ class TestChromaDBClient:
|
||||
) -> None:
|
||||
"""Test aadd_documents with documents that have no metadata."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
@@ -434,7 +434,7 @@ class TestChromaDBClient:
|
||||
"""Test that search queries the collection correctly."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2"]],
|
||||
"documents": [["Document 1", "Document 2"]],
|
||||
@@ -444,13 +444,13 @@ class TestChromaDBClient:
|
||||
|
||||
results = client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
n_results=5,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
@@ -466,7 +466,7 @@ class TestChromaDBClient:
|
||||
"""Test search with optional parameters."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1", "doc2", "doc3"]],
|
||||
"documents": [["Document 1", "Document 2", "Document 3"]],
|
||||
@@ -499,7 +499,7 @@ class TestChromaDBClient:
|
||||
"""Test that asearch queries the collection correctly."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
@@ -515,13 +515,13 @@ class TestChromaDBClient:
|
||||
collection_name="test_collection", query="test query"
|
||||
)
|
||||
|
||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
||||
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||
name="test_collection",
|
||||
embedding_function=async_client.embedding_function,
|
||||
)
|
||||
mock_collection.query.assert_called_once_with(
|
||||
query_texts=["test query"],
|
||||
n_results=10,
|
||||
n_results=5,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
@@ -540,7 +540,7 @@ class TestChromaDBClient:
|
||||
"""Test asearch with optional parameters."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
mock_collection.query = AsyncMock(
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client import QdrantClient as SyncQdrantClient
|
||||
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
@@ -435,7 +436,7 @@ class TestQdrantClient:
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
@@ -540,7 +541,7 @@ class TestQdrantClient:
|
||||
call_args = mock_async_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user