mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
feat: add configurable search parameters for RAG, knowledge, and memory (#3531)
- Add limit and score_threshold to BaseRagConfig, propagate to clients - Update default search params in RAG storage, knowledge, and memory (limit=5, threshold=0.6) - Fix linting (ruff, mypy, PERF203) and refactor save logic - Update tests for new defaults and ChromaDB behavior
This commit is contained in:
@@ -43,7 +43,7 @@ class Knowledge(BaseModel):
|
|||||||
self.sources = sources
|
self.sources = sources
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Query across all knowledge sources to find the most relevant information.
|
Query across all knowledge sources to find the most relevant information.
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel):
|
|||||||
score_threshold (float): The minimum score for a document to be considered relevant.
|
score_threshold (float): The minimum score for a document to be considered relevant.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results_limit: int = Field(default=3, description="The number of results to return")
|
results_limit: int = Field(default=5, description="The number of results to return")
|
||||||
score_threshold: float = Field(
|
score_threshold: float = Field(
|
||||||
default=0.35,
|
default=0.6,
|
||||||
description="The minimum score for a result to be considered relevant",
|
description="The minimum score for a result to be considered relevant",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: list[str],
|
query: list[str],
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
metadata_filter: dict[str, Any] | None = None,
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""Search for documents in the knowledge base."""
|
"""Search for documents in the knowledge base."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: list[str],
|
query: list[str],
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
metadata_filter: dict[str, Any] | None = None,
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
try:
|
try:
|
||||||
if not query:
|
if not query:
|
||||||
@@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during knowledge search: {e!s}")
|
logging.error(
|
||||||
|
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
@@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
)
|
)
|
||||||
client.delete_collection(collection_name=collection_name)
|
client.delete_collection(collection_name=collection_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during knowledge reset: {e!s}")
|
logging.error(
|
||||||
|
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
def save(self, documents: list[str]) -> None:
|
def save(self, documents: list[str]) -> None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
from typing import Any
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryFailedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
from crewai.events.types.memory_events import (
|
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
|
||||||
MemoryQueryFailedEvent,
|
|
||||||
MemorySaveStartedEvent,
|
|
||||||
MemorySaveCompletedEvent,
|
|
||||||
MemorySaveFailedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EntityMemory(Memory):
|
class EntityMemory(Memory):
|
||||||
@@ -31,10 +31,10 @@ class EntityMemory(Memory):
|
|||||||
if memory_provider == "mem0":
|
if memory_provider == "mem0":
|
||||||
try:
|
try:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||||
)
|
) from e
|
||||||
config = embedder_config.get("config") if embedder_config else None
|
config = embedder_config.get("config") if embedder_config else None
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||||
else:
|
else:
|
||||||
@@ -90,8 +90,8 @@ class EntityMemory(Memory):
|
|||||||
saved_count = 0
|
saved_count = 0
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
try:
|
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||||
for item in items:
|
"""Save a single item and return success status."""
|
||||||
try:
|
try:
|
||||||
if self._memory_provider == "mem0":
|
if self._memory_provider == "mem0":
|
||||||
data = f"""
|
data = f"""
|
||||||
@@ -103,10 +103,18 @@ class EntityMemory(Memory):
|
|||||||
else:
|
else:
|
||||||
data = f"{item.name}({item.type}): {item.description}"
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
|
|
||||||
super().save(data, item.metadata)
|
super(EntityMemory, self).save(data, item.metadata)
|
||||||
saved_count += 1
|
return True, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(f"{item.name}: {str(e)}")
|
return False, f"{item.name}: {e!s}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in items:
|
||||||
|
success, error = save_single_item(item)
|
||||||
|
if success:
|
||||||
|
saved_count += 1
|
||||||
|
else:
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
if is_batch:
|
if is_batch:
|
||||||
emit_value = f"Saved {saved_count} entities"
|
emit_value = f"Saved {saved_count} entities"
|
||||||
@@ -153,8 +161,8 @@ class EntityMemory(Memory):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
):
|
):
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -206,4 +214,6 @@ class EntityMemory(Memory):
|
|||||||
try:
|
try:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
raise Exception(
|
||||||
|
f"An error occurred while resetting the entity memory: {e}"
|
||||||
|
) from e
|
||||||
|
|||||||
34
src/crewai/memory/external/external_memory.py
vendored
34
src/crewai/memory/external/external_memory.py
vendored
@@ -1,41 +1,41 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
||||||
import time
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryFailedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.storage.interface import Storage
|
from crewai.memory.storage.interface import Storage
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
from crewai.events.types.memory_events import (
|
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
|
||||||
MemoryQueryFailedEvent,
|
|
||||||
MemorySaveStartedEvent,
|
|
||||||
MemorySaveCompletedEvent,
|
|
||||||
MemorySaveFailedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
|
|
||||||
|
|
||||||
class ExternalMemory(Memory):
|
class ExternalMemory(Memory):
|
||||||
def __init__(self, storage: Optional[Storage] = None, **data: Any):
|
def __init__(self, storage: Storage | None = None, **data: Any):
|
||||||
super().__init__(storage=storage, **data)
|
super().__init__(storage=storage, **data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
|
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
|
|
||||||
return Mem0Storage(type="external", crew=crew, config=config)
|
return Mem0Storage(type="external", crew=crew, config=config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def external_supported_storages() -> Dict[str, Any]:
|
def external_supported_storages() -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"mem0": ExternalMemory._configure_mem0,
|
"mem0": ExternalMemory._configure_mem0,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
|
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||||
if not embedder_config:
|
if not embedder_config:
|
||||||
raise ValueError("embedder_config is required")
|
raise ValueError("embedder_config is required")
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class ExternalMemory(Memory):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves a value into the external storage."""
|
"""Saves a value into the external storage."""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -103,8 +103,8 @@ class ExternalMemory(Memory):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
):
|
):
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -12,8 +12,8 @@ class Memory(BaseModel):
|
|||||||
Base class for memory, now supporting agent tags and generic metadata.
|
Base class for memory, now supporting agent tags and generic metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embedder_config: Optional[Dict[str, Any]] = None
|
embedder_config: dict[str, Any] | None = None
|
||||||
crew: Optional[Any] = None
|
crew: Any | None = None
|
||||||
|
|
||||||
storage: Any
|
storage: Any
|
||||||
_agent: Optional["Agent"] = None
|
_agent: Optional["Agent"] = None
|
||||||
@@ -45,7 +45,7 @@ class Memory(BaseModel):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
metadata = metadata or {}
|
metadata = metadata or {}
|
||||||
|
|
||||||
@@ -54,9 +54,9 @@ class Memory(BaseModel):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
) -> List[Any]:
|
) -> list[Any]:
|
||||||
return self.storage.search(
|
return self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
from typing import Any, Dict, Optional
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryFailedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
from crewai.events.types.memory_events import (
|
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
|
||||||
MemoryQueryFailedEvent,
|
|
||||||
MemorySaveStartedEvent,
|
|
||||||
MemorySaveCompletedEvent,
|
|
||||||
MemorySaveFailedEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShortTermMemory(Memory):
|
class ShortTermMemory(Memory):
|
||||||
@@ -26,17 +26,17 @@ class ShortTermMemory(Memory):
|
|||||||
MemoryItem instances.
|
MemoryItem instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_memory_provider: Optional[str] = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||||
if memory_provider == "mem0":
|
if memory_provider == "mem0":
|
||||||
try:
|
try:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||||
)
|
) from e
|
||||||
config = embedder_config.get("config") if embedder_config else None
|
config = embedder_config.get("config") if embedder_config else None
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||||
else:
|
else:
|
||||||
@@ -56,7 +56,7 @@ class ShortTermMemory(Memory):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -112,8 +112,8 @@ class ShortTermMemory(Memory):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
):
|
):
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -167,4 +167,4 @@ class ShortTermMemory(Memory):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"An error occurred while resetting the short-term memory: {e}"
|
f"An error occurred while resetting the short-term memory: {e}"
|
||||||
)
|
) from e
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class Mem0Storage(Storage):
|
|||||||
self.memory.add(conversations, **params)
|
self.memory.add(conversations, **params)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
self, query: str, limit: int = 5, score_threshold: float = 0.6
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
params = {
|
params = {
|
||||||
"query": query,
|
"query": query,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
client.add_documents(collection_name=collection_name, documents=[document])
|
client.add_documents(collection_name=collection_name, documents=[document])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during {self.type} save: {e!s}")
|
logging.error(
|
||||||
|
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
filter: dict[str, Any] | None = None,
|
filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
@@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during {self.type} search: {e!s}")
|
logging.error(
|
||||||
|
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient):
|
|||||||
Attributes:
|
Attributes:
|
||||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||||
embedding_function: Function to generate embeddings for documents.
|
embedding_function: Function to generate embeddings for documents.
|
||||||
|
default_limit: Default number of results to return in searches.
|
||||||
|
default_score_threshold: Default minimum score for search results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: ChromaDBClientType,
|
client: ChromaDBClientType,
|
||||||
embedding_function: ChromaEmbeddingFunction,
|
embedding_function: ChromaEmbeddingFunction,
|
||||||
|
default_limit: int = 5,
|
||||||
|
default_score_threshold: float = 0.6,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize ChromaDBClient with client and embedding function.
|
"""Initialize ChromaDBClient with client and embedding function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: Pre-configured ChromaDB client instance.
|
client: Pre-configured ChromaDB client instance.
|
||||||
embedding_function: Embedding function for text to vector conversion.
|
embedding_function: Embedding function for text to vector conversion.
|
||||||
|
default_limit: Default number of results to return in searches.
|
||||||
|
default_score_threshold: Default minimum score for search results.
|
||||||
"""
|
"""
|
||||||
self.client = client
|
self.client = client
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
|
self.default_limit = default_limit
|
||||||
|
self.default_score_threshold = default_score_threshold
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||||
@@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
if not documents:
|
if not documents:
|
||||||
raise ValueError("Documents list cannot be empty")
|
raise ValueError("Documents list cannot be empty")
|
||||||
|
|
||||||
collection = self.client.get_collection(
|
collection = self.client.get_or_create_collection(
|
||||||
name=_sanitize_collection_name(collection_name),
|
name=_sanitize_collection_name(collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
@@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
if not documents:
|
if not documents:
|
||||||
raise ValueError("Documents list cannot be empty")
|
raise ValueError("Documents list cannot be empty")
|
||||||
|
|
||||||
collection = await self.client.get_collection(
|
collection = await self.client.get_or_create_collection(
|
||||||
name=_sanitize_collection_name(collection_name),
|
name=_sanitize_collection_name(collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
@@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient):
|
|||||||
"Use asearch() for AsyncClientAPI."
|
"Use asearch() for AsyncClientAPI."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "limit" not in kwargs:
|
||||||
|
kwargs["limit"] = self.default_limit
|
||||||
|
if "score_threshold" not in kwargs:
|
||||||
|
kwargs["score_threshold"] = self.default_score_threshold
|
||||||
|
|
||||||
params = _extract_search_params(kwargs)
|
params = _extract_search_params(kwargs)
|
||||||
|
|
||||||
collection = self.client.get_collection(
|
collection = self.client.get_or_create_collection(
|
||||||
name=_sanitize_collection_name(params.collection_name),
|
name=_sanitize_collection_name(params.collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
@@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient):
|
|||||||
"Use search() for ClientAPI."
|
"Use search() for ClientAPI."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "limit" not in kwargs:
|
||||||
|
kwargs["limit"] = self.default_limit
|
||||||
|
if "score_threshold" not in kwargs:
|
||||||
|
kwargs["score_threshold"] = self.default_score_threshold
|
||||||
|
|
||||||
params = _extract_search_params(kwargs)
|
params = _extract_search_params(kwargs)
|
||||||
|
|
||||||
collection = await self.client.get_collection(
|
collection = await self.client.get_or_create_collection(
|
||||||
name=_sanitize_collection_name(params.collection_name),
|
name=_sanitize_collection_name(params.collection_name),
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
|||||||
return ChromaDBClient(
|
return ChromaDBClient(
|
||||||
client=client,
|
client=client,
|
||||||
embedding_function=config.embedding_function,
|
embedding_function=config.embedding_function,
|
||||||
|
default_limit=config.limit,
|
||||||
|
default_score_threshold=config.score_threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,3 +14,5 @@ class BaseRagConfig:
|
|||||||
|
|
||||||
provider: SupportedProvider = field(init=False)
|
provider: SupportedProvider = field(init=False)
|
||||||
embedding_function: Any | None = field(default=None)
|
embedding_function: Any | None = field(default=None)
|
||||||
|
limit: int = field(default=5)
|
||||||
|
score_threshold: float = field(default=0.6)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing_extensions import Unpack
|
|||||||
|
|
||||||
from crewai.rag.core.base_client import (
|
from crewai.rag.core.base_client import (
|
||||||
BaseClient,
|
BaseClient,
|
||||||
BaseCollectionParams,
|
|
||||||
BaseCollectionAddParams,
|
BaseCollectionAddParams,
|
||||||
|
BaseCollectionParams,
|
||||||
BaseCollectionSearchParams,
|
BaseCollectionSearchParams,
|
||||||
)
|
)
|
||||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||||
@@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import (
|
|||||||
QdrantCollectionCreateParams,
|
QdrantCollectionCreateParams,
|
||||||
)
|
)
|
||||||
from crewai.rag.qdrant.utils import (
|
from crewai.rag.qdrant.utils import (
|
||||||
|
_create_point_from_document,
|
||||||
|
_get_collection_params,
|
||||||
_is_async_client,
|
_is_async_client,
|
||||||
_is_async_embedding_function,
|
_is_async_embedding_function,
|
||||||
_is_sync_client,
|
_is_sync_client,
|
||||||
_create_point_from_document,
|
|
||||||
_get_collection_params,
|
|
||||||
_prepare_search_params,
|
_prepare_search_params,
|
||||||
_process_search_results,
|
_process_search_results,
|
||||||
)
|
)
|
||||||
@@ -38,21 +38,29 @@ class QdrantClient(BaseClient):
|
|||||||
Attributes:
|
Attributes:
|
||||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||||
embedding_function: Function to generate embeddings for documents.
|
embedding_function: Function to generate embeddings for documents.
|
||||||
|
default_limit: Default number of results to return in searches.
|
||||||
|
default_score_threshold: Default minimum score for search results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: QdrantClientType,
|
client: QdrantClientType,
|
||||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||||
|
default_limit: int = 5,
|
||||||
|
default_score_threshold: float = 0.6,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize QdrantClient with client and embedding function.
|
"""Initialize QdrantClient with client and embedding function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: Pre-configured Qdrant client instance.
|
client: Pre-configured Qdrant client instance.
|
||||||
embedding_function: Embedding function for text to vector conversion.
|
embedding_function: Embedding function for text to vector conversion.
|
||||||
|
default_limit: Default number of results to return in searches.
|
||||||
|
default_score_threshold: Default minimum score for search results.
|
||||||
"""
|
"""
|
||||||
self.client = client
|
self.client = client
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
|
self.default_limit = default_limit
|
||||||
|
self.default_score_threshold = default_score_threshold
|
||||||
|
|
||||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||||
"""Create a new collection in Qdrant.
|
"""Create a new collection in Qdrant.
|
||||||
@@ -332,9 +340,9 @@ class QdrantClient(BaseClient):
|
|||||||
|
|
||||||
collection_name = kwargs["collection_name"]
|
collection_name = kwargs["collection_name"]
|
||||||
query = kwargs["query"]
|
query = kwargs["query"]
|
||||||
limit = kwargs.get("limit", 10)
|
limit = kwargs.get("limit", self.default_limit)
|
||||||
metadata_filter = kwargs.get("metadata_filter")
|
metadata_filter = kwargs.get("metadata_filter")
|
||||||
score_threshold = kwargs.get("score_threshold")
|
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||||
|
|
||||||
if not self.client.collection_exists(collection_name):
|
if not self.client.collection_exists(collection_name):
|
||||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||||
@@ -387,9 +395,9 @@ class QdrantClient(BaseClient):
|
|||||||
|
|
||||||
collection_name = kwargs["collection_name"]
|
collection_name = kwargs["collection_name"]
|
||||||
query = kwargs["query"]
|
query = kwargs["query"]
|
||||||
limit = kwargs.get("limit", 10)
|
limit = kwargs.get("limit", self.default_limit)
|
||||||
metadata_filter = kwargs.get("metadata_filter")
|
metadata_filter = kwargs.get("metadata_filter")
|
||||||
score_threshold = kwargs.get("score_threshold")
|
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||||
|
|
||||||
if not await self.client.collection_exists(collection_name):
|
if not await self.client.collection_exists(collection_name):
|
||||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Factory functions for creating Qdrant clients from configuration."""
|
"""Factory functions for creating Qdrant clients from configuration."""
|
||||||
|
|
||||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||||
|
|
||||||
from crewai.rag.qdrant.client import QdrantClient
|
from crewai.rag.qdrant.client import QdrantClient
|
||||||
from crewai.rag.qdrant.config import QdrantConfig
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
|
|
||||||
@@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient:
|
|||||||
|
|
||||||
qdrant_client = SyncQdrantClientBase(**config.options)
|
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||||
return QdrantClient(
|
return QdrantClient(
|
||||||
client=qdrant_client, embedding_function=config.embedding_function
|
client=qdrant_client,
|
||||||
|
embedding_function=config.embedding_function,
|
||||||
|
default_limit=config.limit,
|
||||||
|
default_score_threshold=config.score_threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,9 +41,9 @@ class BaseRAGStorage(ABC):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 5,
|
||||||
filter: dict[str, Any] | None = None,
|
filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
"""Search for entries in the storage."""
|
"""Search for entries in the storage."""
|
||||||
|
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ class TestChromaDBClient:
|
|||||||
def test_add_documents(self, client, mock_chromadb_client) -> None:
|
def test_add_documents(self, client, mock_chromadb_client) -> None:
|
||||||
"""Test that add_documents adds documents to collection."""
|
"""Test that add_documents adds documents to collection."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
|
|
||||||
documents: list[BaseRecord] = [
|
documents: list[BaseRecord] = [
|
||||||
{
|
{
|
||||||
@@ -247,7 +247,7 @@ class TestChromaDBClient:
|
|||||||
|
|
||||||
client.add_documents(collection_name="test_collection", documents=documents)
|
client.add_documents(collection_name="test_collection", documents=documents)
|
||||||
|
|
||||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||||
name="test_collection",
|
name="test_collection",
|
||||||
embedding_function=client.embedding_function,
|
embedding_function=client.embedding_function,
|
||||||
)
|
)
|
||||||
@@ -262,7 +262,7 @@ class TestChromaDBClient:
|
|||||||
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
|
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
|
||||||
"""Test add_documents with custom document IDs."""
|
"""Test add_documents with custom document IDs."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
|
|
||||||
documents: list[BaseRecord] = [
|
documents: list[BaseRecord] = [
|
||||||
{
|
{
|
||||||
@@ -288,7 +288,7 @@ class TestChromaDBClient:
|
|||||||
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
|
def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None:
|
||||||
"""Test add_documents with documents that have no metadata."""
|
"""Test add_documents with documents that have no metadata."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
|
|
||||||
documents: list[BaseRecord] = [
|
documents: list[BaseRecord] = [
|
||||||
{"content": "Document without metadata"},
|
{"content": "Document without metadata"},
|
||||||
@@ -308,7 +308,7 @@ class TestChromaDBClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test add_documents when all documents have no metadata."""
|
"""Test add_documents when all documents have no metadata."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
|
|
||||||
documents: list[BaseRecord] = [
|
documents: list[BaseRecord] = [
|
||||||
{"content": "Document 1"},
|
{"content": "Document 1"},
|
||||||
@@ -335,7 +335,7 @@ class TestChromaDBClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||||
mock_collection = AsyncMock()
|
mock_collection = AsyncMock()
|
||||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||||
return_value=mock_collection
|
return_value=mock_collection
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class TestChromaDBClient:
|
|||||||
collection_name="test_collection", documents=documents
|
collection_name="test_collection", documents=documents
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||||
name="test_collection",
|
name="test_collection",
|
||||||
embedding_function=async_client.embedding_function,
|
embedding_function=async_client.embedding_function,
|
||||||
)
|
)
|
||||||
@@ -368,7 +368,7 @@ class TestChromaDBClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test aadd_documents with custom document IDs."""
|
"""Test aadd_documents with custom document IDs."""
|
||||||
mock_collection = AsyncMock()
|
mock_collection = AsyncMock()
|
||||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||||
return_value=mock_collection
|
return_value=mock_collection
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class TestChromaDBClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test aadd_documents with documents that have no metadata."""
|
"""Test aadd_documents with documents that have no metadata."""
|
||||||
mock_collection = AsyncMock()
|
mock_collection = AsyncMock()
|
||||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||||
return_value=mock_collection
|
return_value=mock_collection
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -434,7 +434,7 @@ class TestChromaDBClient:
|
|||||||
"""Test that search queries the collection correctly."""
|
"""Test that search queries the collection correctly."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
mock_collection.query.return_value = {
|
mock_collection.query.return_value = {
|
||||||
"ids": [["doc1", "doc2"]],
|
"ids": [["doc1", "doc2"]],
|
||||||
"documents": [["Document 1", "Document 2"]],
|
"documents": [["Document 1", "Document 2"]],
|
||||||
@@ -444,13 +444,13 @@ class TestChromaDBClient:
|
|||||||
|
|
||||||
results = client.search(collection_name="test_collection", query="test query")
|
results = client.search(collection_name="test_collection", query="test query")
|
||||||
|
|
||||||
mock_chromadb_client.get_collection.assert_called_once_with(
|
mock_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||||
name="test_collection",
|
name="test_collection",
|
||||||
embedding_function=client.embedding_function,
|
embedding_function=client.embedding_function,
|
||||||
)
|
)
|
||||||
mock_collection.query.assert_called_once_with(
|
mock_collection.query.assert_called_once_with(
|
||||||
query_texts=["test query"],
|
query_texts=["test query"],
|
||||||
n_results=10,
|
n_results=5,
|
||||||
where=None,
|
where=None,
|
||||||
where_document=None,
|
where_document=None,
|
||||||
include=["metadatas", "documents", "distances"],
|
include=["metadatas", "documents", "distances"],
|
||||||
@@ -466,7 +466,7 @@ class TestChromaDBClient:
|
|||||||
"""Test search with optional parameters."""
|
"""Test search with optional parameters."""
|
||||||
mock_collection = Mock()
|
mock_collection = Mock()
|
||||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||||
mock_chromadb_client.get_collection.return_value = mock_collection
|
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||||
mock_collection.query.return_value = {
|
mock_collection.query.return_value = {
|
||||||
"ids": [["doc1", "doc2", "doc3"]],
|
"ids": [["doc1", "doc2", "doc3"]],
|
||||||
"documents": [["Document 1", "Document 2", "Document 3"]],
|
"documents": [["Document 1", "Document 2", "Document 3"]],
|
||||||
@@ -499,7 +499,7 @@ class TestChromaDBClient:
|
|||||||
"""Test that asearch queries the collection correctly."""
|
"""Test that asearch queries the collection correctly."""
|
||||||
mock_collection = AsyncMock()
|
mock_collection = AsyncMock()
|
||||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||||
return_value=mock_collection
|
return_value=mock_collection
|
||||||
)
|
)
|
||||||
mock_collection.query = AsyncMock(
|
mock_collection.query = AsyncMock(
|
||||||
@@ -515,13 +515,13 @@ class TestChromaDBClient:
|
|||||||
collection_name="test_collection", query="test query"
|
collection_name="test_collection", query="test query"
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_async_chromadb_client.get_collection.assert_called_once_with(
|
mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
|
||||||
name="test_collection",
|
name="test_collection",
|
||||||
embedding_function=async_client.embedding_function,
|
embedding_function=async_client.embedding_function,
|
||||||
)
|
)
|
||||||
mock_collection.query.assert_called_once_with(
|
mock_collection.query.assert_called_once_with(
|
||||||
query_texts=["test query"],
|
query_texts=["test query"],
|
||||||
n_results=10,
|
n_results=5,
|
||||||
where=None,
|
where=None,
|
||||||
where_document=None,
|
where_document=None,
|
||||||
include=["metadatas", "documents", "distances"],
|
include=["metadatas", "documents", "distances"],
|
||||||
@@ -540,7 +540,7 @@ class TestChromaDBClient:
|
|||||||
"""Test asearch with optional parameters."""
|
"""Test asearch with optional parameters."""
|
||||||
mock_collection = AsyncMock()
|
mock_collection = AsyncMock()
|
||||||
mock_collection.metadata = {"hnsw:space": "cosine"}
|
mock_collection.metadata = {"hnsw:space": "cosine"}
|
||||||
mock_async_chromadb_client.get_collection = AsyncMock(
|
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||||
return_value=mock_collection
|
return_value=mock_collection
|
||||||
)
|
)
|
||||||
mock_collection.query = AsyncMock(
|
mock_collection.query = AsyncMock(
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
from qdrant_client import AsyncQdrantClient
|
||||||
|
from qdrant_client import QdrantClient as SyncQdrantClient
|
||||||
|
|
||||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||||
from crewai.rag.qdrant.client import QdrantClient
|
from crewai.rag.qdrant.client import QdrantClient
|
||||||
@@ -435,7 +436,7 @@ class TestQdrantClient:
|
|||||||
call_args = mock_qdrant_client.query_points.call_args
|
call_args = mock_qdrant_client.query_points.call_args
|
||||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||||
assert call_args.kwargs["limit"] == 10
|
assert call_args.kwargs["limit"] == 5
|
||||||
assert call_args.kwargs["with_payload"] is True
|
assert call_args.kwargs["with_payload"] is True
|
||||||
assert call_args.kwargs["with_vectors"] is False
|
assert call_args.kwargs["with_vectors"] is False
|
||||||
|
|
||||||
@@ -540,7 +541,7 @@ class TestQdrantClient:
|
|||||||
call_args = mock_async_qdrant_client.query_points.call_args
|
call_args = mock_async_qdrant_client.query_points.call_args
|
||||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||||
assert call_args.kwargs["limit"] == 10
|
assert call_args.kwargs["limit"] == 5
|
||||||
assert call_args.kwargs["with_payload"] is True
|
assert call_args.kwargs["with_payload"] is True
|
||||||
assert call_args.kwargs["with_vectors"] is False
|
assert call_args.kwargs["with_vectors"] is False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user