diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 3330ba6ce..fc5493ceb 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -43,7 +43,7 @@ class Knowledge(BaseModel): self.sources = sources def query( - self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6 ) -> list[SearchResult]: """ Query across all knowledge sources to find the most relevant information. diff --git a/src/crewai/knowledge/knowledge_config.py b/src/crewai/knowledge/knowledge_config.py index e84341f6a..67f0ee44b 100644 --- a/src/crewai/knowledge/knowledge_config.py +++ b/src/crewai/knowledge/knowledge_config.py @@ -9,8 +9,8 @@ class KnowledgeConfig(BaseModel): score_threshold (float): The minimum score for a document to be considered relevant. """ - results_limit: int = Field(default=3, description="The number of results to return") + results_limit: int = Field(default=5, description="The number of results to return") score_threshold: float = Field( - default=0.35, + default=0.6, description="The minimum score for a result to be considered relevant", ) diff --git a/src/crewai/knowledge/storage/base_knowledge_storage.py b/src/crewai/knowledge/storage/base_knowledge_storage.py index 376ed6612..2bc63fb30 100644 --- a/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -11,9 +11,9 @@ class BaseKnowledgeStorage(ABC): def search( self, query: list[str], - limit: int = 3, + limit: int = 5, metadata_filter: dict[str, Any] | None = None, - score_threshold: float = 0.35, + score_threshold: float = 0.6, ) -> list[SearchResult]: """Search for documents in the knowledge base.""" diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 4aeb58e15..3eb70946f 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -1,4 +1,5 @@ import logging +import traceback import warnings from typing import Any, cast @@ -49,9 +50,9 @@ class KnowledgeStorage(BaseKnowledgeStorage): def search( self, query: list[str], - limit: int = 3, + limit: int = 5, metadata_filter: dict[str, Any] | None = None, - score_threshold: float = 0.35, + score_threshold: float = 0.6, ) -> list[SearchResult]: try: if not query: @@ -73,7 +74,9 @@ class KnowledgeStorage(BaseKnowledgeStorage): score_threshold=score_threshold, ) except Exception as e: - logging.error(f"Error during knowledge search: {e!s}") + logging.error( + f"Error during knowledge search: {e!s}\n{traceback.format_exc()}" + ) return [] def reset(self) -> None: @@ -86,7 +89,9 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) client.delete_collection(collection_name=collection_name) except Exception as e: - logging.error(f"Error during knowledge reset: {e!s}") + logging.error( + f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}" + ) def save(self, documents: list[str]) -> None: try: diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index eed044d48..5176db882 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -1,20 +1,20 @@ -from typing import Any import time +from typing import Any from pydantic import PrivateAttr +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryFailedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveFailedEvent, + MemorySaveStartedEvent, +) from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.memory import Memory from crewai.memory.storage.rag_storage import RAGStorage -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.memory_events import ( - MemoryQueryStartedEvent, - MemoryQueryCompletedEvent, - MemoryQueryFailedEvent, - MemorySaveStartedEvent, - MemorySaveCompletedEvent, - MemorySaveFailedEvent, -) class EntityMemory(Memory): @@ -31,10 +31,10 @@ class EntityMemory(Memory): if memory_provider == "mem0": try: from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: + except ImportError as e: raise ImportError( "Mem0 is not installed. Please install it with `pip install mem0ai`." - ) + ) from e config = embedder_config.get("config") if embedder_config else None storage = Mem0Storage(type="short_term", crew=crew, config=config) else: @@ -90,23 +90,31 @@ class EntityMemory(Memory): saved_count = 0 errors = [] + def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]: + """Save a single item and return success status.""" + try: + if self._memory_provider == "mem0": + data = f""" + Remember details about the following entity: + Name: {item.name} + Type: {item.type} + Entity Description: {item.description} + """ + else: + data = f"{item.name}({item.type}): {item.description}" + + super(EntityMemory, self).save(data, item.metadata) + return True, None + except Exception as e: + return False, f"{item.name}: {e!s}" + try: for item in items: - try: - if self._memory_provider == "mem0": - data = f""" - Remember details about the following entity: - Name: {item.name} - Type: {item.type} - Entity Description: {item.description} - """ - else: - data = f"{item.name}({item.type}): {item.description}" - - super().save(data, item.metadata) + success, error = save_single_item(item) + if success: saved_count += 1 - except Exception as e: - errors.append(f"{item.name}: {str(e)}") + else: + errors.append(error) if is_batch: emit_value = f"Saved {saved_count} entities" @@ -153,8 +161,8 @@ class EntityMemory(Memory): def search( self, query: str, - limit: int = 3, - score_threshold: float = 0.35, + limit: int = 5, + score_threshold: float = 0.6, ): crewai_event_bus.emit( self, @@ -206,4 +214,6 @@ class EntityMemory(Memory): try: self.storage.reset() except Exception as e: - raise Exception(f"An error occurred while resetting the entity memory: {e}") + raise Exception( + f"An error occurred while resetting the entity memory: {e}" + ) from e diff --git a/src/crewai/memory/external/external_memory.py b/src/crewai/memory/external/external_memory.py index 7fbbea9a1..9cdb6872f 100644 --- a/src/crewai/memory/external/external_memory.py +++ b/src/crewai/memory/external/external_memory.py @@ -1,41 +1,41 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional import time +from typing import TYPE_CHECKING, Any +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryFailedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveFailedEvent, + MemorySaveStartedEvent, +) from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.memory import Memory from crewai.memory.storage.interface import Storage -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.memory_events import ( - MemoryQueryStartedEvent, - MemoryQueryCompletedEvent, - MemoryQueryFailedEvent, - MemorySaveStartedEvent, - MemorySaveCompletedEvent, - MemorySaveFailedEvent, -) if TYPE_CHECKING: from crewai.memory.storage.mem0_storage import Mem0Storage class ExternalMemory(Memory): - def __init__(self, storage: Optional[Storage] = None, **data: Any): + def __init__(self, storage: Storage | None = None, **data: Any): super().__init__(storage=storage, **data) @staticmethod - def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage": + def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage": from crewai.memory.storage.mem0_storage import Mem0Storage return Mem0Storage(type="external", crew=crew, config=config) @staticmethod - def external_supported_storages() -> Dict[str, Any]: + def external_supported_storages() -> dict[str, Any]: return { "mem0": ExternalMemory._configure_mem0, } @staticmethod - def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage: + def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage: if not embedder_config: raise ValueError("embedder_config is required") @@ -52,7 +52,7 @@ class ExternalMemory(Memory): def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: """Saves a value into the external storage.""" crewai_event_bus.emit( @@ -103,8 +103,8 @@ class ExternalMemory(Memory): def search( self, query: str, - limit: int = 3, - score_threshold: float = 0.35, + limit: int = 5, + score_threshold: float = 0.6, ): crewai_event_bus.emit( self, diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 2301c91e8..fae7cc2f1 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional from pydantic import BaseModel @@ -12,8 +12,8 @@ class Memory(BaseModel): Base class for memory, now supporting agent tags and generic metadata. """ - embedder_config: Optional[Dict[str, Any]] = None - crew: Optional[Any] = None + embedder_config: dict[str, Any] | None = None + crew: Any | None = None storage: Any _agent: Optional["Agent"] = None @@ -45,7 +45,7 @@ class Memory(BaseModel): def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: metadata = metadata or {} @@ -54,9 +54,9 @@ class Memory(BaseModel): def search( self, query: str, - limit: int = 3, - score_threshold: float = 0.35, - ) -> List[Any]: + limit: int = 5, + score_threshold: float = 0.6, + ) -> list[Any]: return self.storage.search( query=query, limit=limit, score_threshold=score_threshold ) diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 97fd0b320..ceb6de82b 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -1,20 +1,20 @@ -from typing import Any, Dict, Optional import time +from typing import Any from pydantic import PrivateAttr +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryFailedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveFailedEvent, + MemorySaveStartedEvent, +) from crewai.memory.memory import Memory from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.memory.storage.rag_storage import RAGStorage -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.memory_events import ( - MemoryQueryStartedEvent, - MemoryQueryCompletedEvent, - MemoryQueryFailedEvent, - MemorySaveStartedEvent, - MemorySaveCompletedEvent, - MemorySaveFailedEvent, -) class ShortTermMemory(Memory): @@ -26,17 +26,17 @@ class ShortTermMemory(Memory): MemoryItem instances. """ - _memory_provider: Optional[str] = PrivateAttr() + _memory_provider: str | None = PrivateAttr() def __init__(self, crew=None, embedder_config=None, storage=None, path=None): memory_provider = embedder_config.get("provider") if embedder_config else None if memory_provider == "mem0": try: from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: + except ImportError as e: raise ImportError( "Mem0 is not installed. Please install it with `pip install mem0ai`." - ) + ) from e config = embedder_config.get("config") if embedder_config else None storage = Mem0Storage(type="short_term", crew=crew, config=config) else: @@ -56,7 +56,7 @@ class ShortTermMemory(Memory): def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: crewai_event_bus.emit( self, @@ -112,8 +112,8 @@ class ShortTermMemory(Memory): def search( self, query: str, - limit: int = 3, - score_threshold: float = 0.35, + limit: int = 5, + score_threshold: float = 0.6, ): crewai_event_bus.emit( self, @@ -167,4 +167,4 @@ class ShortTermMemory(Memory): except Exception as e: raise Exception( f"An error occurred while resetting the short-term memory: {e}" - ) + ) from e diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 128aa6ed8..036b9d2a4 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -151,7 +151,7 @@ class Mem0Storage(Storage): self.memory.add(conversations, **params) def search( - self, query: str, limit: int = 3, score_threshold: float = 0.35 + self, query: str, limit: int = 5, score_threshold: float = 0.6 ) -> list[Any]: params = { "query": query, diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index b52ec384e..7e66a262c 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -1,4 +1,5 @@ import logging +import traceback import warnings from typing import Any @@ -86,14 +87,16 @@ class RAGStorage(BaseRAGStorage): client.add_documents(collection_name=collection_name, documents=[document]) except Exception as e: - logging.error(f"Error during {self.type} save: {e!s}") + logging.error( + f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}" + ) def search( self, query: str, - limit: int = 3, + limit: int = 5, filter: dict[str, Any] | None = None, - score_threshold: float = 0.35, + score_threshold: float = 0.6, ) -> list[Any]: try: client = self._get_client() @@ -110,7 +113,9 @@ class RAGStorage(BaseRAGStorage): score_threshold=score_threshold, ) except Exception as e: - logging.error(f"Error during {self.type} search: {e!s}") + logging.error( + f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}" + ) return [] def reset(self) -> None: diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py index 3a9d140d4..0caa4f39c 100644 --- a/src/crewai/rag/chromadb/client.py +++ b/src/crewai/rag/chromadb/client.py @@ -42,21 +42,29 @@ class ChromaDBClient(BaseClient): Attributes: client: ChromaDB client instance (ClientAPI or AsyncClientAPI). embedding_function: Function to generate embeddings for documents. + default_limit: Default number of results to return in searches. + default_score_threshold: Default minimum score for search results. """ def __init__( self, client: ChromaDBClientType, embedding_function: ChromaEmbeddingFunction, + default_limit: int = 5, + default_score_threshold: float = 0.6, ) -> None: """Initialize ChromaDBClient with client and embedding function. Args: client: Pre-configured ChromaDB client instance. embedding_function: Embedding function for text to vector conversion. + default_limit: Default number of results to return in searches. + default_score_threshold: Default minimum score for search results. """ self.client = client self.embedding_function = embedding_function + self.default_limit = default_limit + self.default_score_threshold = default_score_threshold def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] @@ -301,7 +309,7 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = self.client.get_collection( + collection = self.client.get_or_create_collection( name=_sanitize_collection_name(collection_name), embedding_function=self.embedding_function, ) @@ -345,7 +353,7 @@ class ChromaDBClient(BaseClient): if not documents: raise ValueError("Documents list cannot be empty") - collection = await self.client.get_collection( + collection = await self.client.get_or_create_collection( name=_sanitize_collection_name(collection_name), embedding_function=self.embedding_function, ) @@ -390,9 +398,14 @@ class ChromaDBClient(BaseClient): "Use asearch() for AsyncClientAPI." ) + if "limit" not in kwargs: + kwargs["limit"] = self.default_limit + if "score_threshold" not in kwargs: + kwargs["score_threshold"] = self.default_score_threshold + params = _extract_search_params(kwargs) - collection = self.client.get_collection( + collection = self.client.get_or_create_collection( name=_sanitize_collection_name(params.collection_name), embedding_function=self.embedding_function, ) @@ -448,9 +461,14 @@ class ChromaDBClient(BaseClient): "Use search() for ClientAPI." ) + if "limit" not in kwargs: + kwargs["limit"] = self.default_limit + if "score_threshold" not in kwargs: + kwargs["score_threshold"] = self.default_score_threshold + params = _extract_search_params(kwargs) - collection = await self.client.get_collection( + collection = await self.client.get_or_create_collection( name=_sanitize_collection_name(params.collection_name), embedding_function=self.embedding_function, ) diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py index 44def6495..a02d350ac 100644 --- a/src/crewai/rag/chromadb/factory.py +++ b/src/crewai/rag/chromadb/factory.py @@ -39,4 +39,6 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: return ChromaDBClient( client=client, embedding_function=config.embedding_function, + default_limit=config.limit, + default_score_threshold=config.score_threshold, ) diff --git a/src/crewai/rag/config/base.py b/src/crewai/rag/config/base.py index b287b6ea6..411c4f7bc 100644 --- a/src/crewai/rag/config/base.py +++ b/src/crewai/rag/config/base.py @@ -14,3 +14,5 @@ class BaseRagConfig: provider: SupportedProvider = field(init=False) embedding_function: Any | None = field(default=None) + limit: int = field(default=5) + score_threshold: float = field(default=0.6) diff --git a/src/crewai/rag/qdrant/client.py b/src/crewai/rag/qdrant/client.py index 3386d3411..c82ad9f8e 100644 --- a/src/crewai/rag/qdrant/client.py +++ b/src/crewai/rag/qdrant/client.py @@ -6,8 +6,8 @@ from typing_extensions import Unpack from crewai.rag.core.base_client import ( BaseClient, - BaseCollectionParams, BaseCollectionAddParams, + BaseCollectionParams, BaseCollectionSearchParams, ) from crewai.rag.core.exceptions import ClientMethodMismatchError @@ -18,11 +18,11 @@ from crewai.rag.qdrant.types import ( QdrantCollectionCreateParams, ) from crewai.rag.qdrant.utils import ( + _create_point_from_document, + _get_collection_params, _is_async_client, _is_async_embedding_function, _is_sync_client, - _create_point_from_document, - _get_collection_params, _prepare_search_params, _process_search_results, ) @@ -38,21 +38,29 @@ class QdrantClient(BaseClient): Attributes: client: Qdrant client instance (QdrantClient or AsyncQdrantClient). embedding_function: Function to generate embeddings for documents. + default_limit: Default number of results to return in searches. + default_score_threshold: Default minimum score for search results. """ def __init__( self, client: QdrantClientType, embedding_function: EmbeddingFunction | AsyncEmbeddingFunction, + default_limit: int = 5, + default_score_threshold: float = 0.6, ) -> None: """Initialize QdrantClient with client and embedding function. Args: client: Pre-configured Qdrant client instance. embedding_function: Embedding function for text to vector conversion. + default_limit: Default number of results to return in searches. + default_score_threshold: Default minimum score for search results. """ self.client = client self.embedding_function = embedding_function + self.default_limit = default_limit + self.default_score_threshold = default_score_threshold def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None: """Create a new collection in Qdrant. @@ -332,9 +340,9 @@ class QdrantClient(BaseClient): collection_name = kwargs["collection_name"] query = kwargs["query"] - limit = kwargs.get("limit", 10) + limit = kwargs.get("limit", self.default_limit) metadata_filter = kwargs.get("metadata_filter") - score_threshold = kwargs.get("score_threshold") + score_threshold = kwargs.get("score_threshold", self.default_score_threshold) if not self.client.collection_exists(collection_name): raise ValueError(f"Collection '{collection_name}' does not exist") @@ -387,9 +395,9 @@ class QdrantClient(BaseClient): collection_name = kwargs["collection_name"] query = kwargs["query"] - limit = kwargs.get("limit", 10) + limit = kwargs.get("limit", self.default_limit) metadata_filter = kwargs.get("metadata_filter") - score_threshold = kwargs.get("score_threshold") + score_threshold = kwargs.get("score_threshold", self.default_score_threshold) if not await self.client.collection_exists(collection_name): raise ValueError(f"Collection '{collection_name}' does not exist") diff --git a/src/crewai/rag/qdrant/factory.py b/src/crewai/rag/qdrant/factory.py index 75529a2a1..512e7a562 100644 --- a/src/crewai/rag/qdrant/factory.py +++ b/src/crewai/rag/qdrant/factory.py @@ -1,6 +1,7 @@ """Factory functions for creating Qdrant clients from configuration.""" from qdrant_client import QdrantClient as SyncQdrantClientBase + from crewai.rag.qdrant.client import QdrantClient from crewai.rag.qdrant.config import QdrantConfig @@ -17,5 +18,8 @@ def create_client(config: QdrantConfig) -> QdrantClient: qdrant_client = SyncQdrantClientBase(**config.options) return QdrantClient( - client=qdrant_client, embedding_function=config.embedding_function + client=qdrant_client, + embedding_function=config.embedding_function, + default_limit=config.limit, + default_score_threshold=config.score_threshold, ) diff --git a/src/crewai/rag/storage/base_rag_storage.py b/src/crewai/rag/storage/base_rag_storage.py index 36b4020b7..772ed4266 100644 --- a/src/crewai/rag/storage/base_rag_storage.py +++ b/src/crewai/rag/storage/base_rag_storage.py @@ -41,9 +41,9 @@ class BaseRAGStorage(ABC): def search( self, query: str, - limit: int = 3, + limit: int = 5, filter: dict[str, Any] | None = None, - score_threshold: float = 0.35, + score_threshold: float = 0.6, ) -> list[Any]: """Search for entries in the storage.""" diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py index 8e0cc66a1..8fef2ff8d 100644 --- a/tests/rag/chromadb/test_client.py +++ b/tests/rag/chromadb/test_client.py @@ -236,7 +236,7 @@ class TestChromaDBClient: def test_add_documents(self, client, mock_chromadb_client) -> None: """Test that add_documents adds documents to collection.""" mock_collection = Mock() - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection documents: list[BaseRecord] = [ { @@ -247,7 +247,7 @@ class TestChromaDBClient: client.add_documents(collection_name="test_collection", documents=documents) - mock_chromadb_client.get_collection.assert_called_once_with( + mock_chromadb_client.get_or_create_collection.assert_called_once_with( name="test_collection", embedding_function=client.embedding_function, ) @@ -262,7 +262,7 @@ class TestChromaDBClient: def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None: """Test add_documents with custom document IDs.""" mock_collection = Mock() - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection documents: list[BaseRecord] = [ { @@ -288,7 +288,7 @@ class TestChromaDBClient: def test_add_documents_without_metadata(self, client, mock_chromadb_client) -> None: """Test add_documents with documents that have no metadata.""" mock_collection = Mock() - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection documents: list[BaseRecord] = [ {"content": "Document without metadata"}, @@ -308,7 +308,7 @@ class TestChromaDBClient: ) -> None: """Test add_documents when all documents have no metadata.""" mock_collection = Mock() - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection documents: list[BaseRecord] = [ {"content": "Document 1"}, @@ -335,7 +335,7 @@ class TestChromaDBClient: ) -> None: """Test that aadd_documents adds documents to collection asynchronously.""" mock_collection = AsyncMock() - mock_async_chromadb_client.get_collection = AsyncMock( + mock_async_chromadb_client.get_or_create_collection = AsyncMock( return_value=mock_collection ) @@ -350,7 +350,7 @@ class TestChromaDBClient: collection_name="test_collection", documents=documents ) - mock_async_chromadb_client.get_collection.assert_called_once_with( + mock_async_chromadb_client.get_or_create_collection.assert_called_once_with( name="test_collection", embedding_function=async_client.embedding_function, ) @@ -368,7 +368,7 @@ class TestChromaDBClient: ) -> None: """Test aadd_documents with custom document IDs.""" mock_collection = AsyncMock() - mock_async_chromadb_client.get_collection = AsyncMock( + mock_async_chromadb_client.get_or_create_collection = AsyncMock( return_value=mock_collection ) @@ -401,7 +401,7 @@ class TestChromaDBClient: ) -> None: """Test aadd_documents with documents that have no metadata.""" mock_collection = AsyncMock() - mock_async_chromadb_client.get_collection = AsyncMock( + mock_async_chromadb_client.get_or_create_collection = AsyncMock( return_value=mock_collection ) @@ -434,7 +434,7 @@ class TestChromaDBClient: """Test that search queries the collection correctly.""" mock_collection = Mock() mock_collection.metadata = {"hnsw:space": "cosine"} - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection mock_collection.query.return_value = { "ids": [["doc1", "doc2"]], "documents": [["Document 1", "Document 2"]], @@ -444,13 +444,13 @@ class TestChromaDBClient: results = client.search(collection_name="test_collection", query="test query") - mock_chromadb_client.get_collection.assert_called_once_with( + mock_chromadb_client.get_or_create_collection.assert_called_once_with( name="test_collection", embedding_function=client.embedding_function, ) mock_collection.query.assert_called_once_with( query_texts=["test query"], - n_results=10, + n_results=5, where=None, where_document=None, include=["metadatas", "documents", "distances"], @@ -466,7 +466,7 @@ class TestChromaDBClient: """Test search with optional parameters.""" mock_collection = Mock() mock_collection.metadata = {"hnsw:space": "cosine"} - mock_chromadb_client.get_collection.return_value = mock_collection + mock_chromadb_client.get_or_create_collection.return_value = mock_collection mock_collection.query.return_value = { "ids": [["doc1", "doc2", "doc3"]], "documents": [["Document 1", "Document 2", "Document 3"]], @@ -499,7 +499,7 @@ class TestChromaDBClient: """Test that asearch queries the collection correctly.""" mock_collection = AsyncMock() mock_collection.metadata = {"hnsw:space": "cosine"} - mock_async_chromadb_client.get_collection = AsyncMock( + mock_async_chromadb_client.get_or_create_collection = AsyncMock( return_value=mock_collection ) mock_collection.query = AsyncMock( @@ -515,13 +515,13 @@ class TestChromaDBClient: collection_name="test_collection", query="test query" ) - mock_async_chromadb_client.get_collection.assert_called_once_with( + mock_async_chromadb_client.get_or_create_collection.assert_called_once_with( name="test_collection", embedding_function=async_client.embedding_function, ) mock_collection.query.assert_called_once_with( query_texts=["test query"], - n_results=10, + n_results=5, where=None, where_document=None, include=["metadatas", "documents", "distances"], @@ -540,7 +540,7 @@ class TestChromaDBClient: """Test asearch with optional parameters.""" mock_collection = AsyncMock() mock_collection.metadata = {"hnsw:space": "cosine"} - mock_async_chromadb_client.get_collection = AsyncMock( + mock_async_chromadb_client.get_or_create_collection = AsyncMock( return_value=mock_collection ) mock_collection.query = AsyncMock( diff --git a/tests/rag/qdrant/test_client.py b/tests/rag/qdrant/test_client.py index a1c16e9bc..9984dce8a 100644 --- a/tests/rag/qdrant/test_client.py +++ b/tests/rag/qdrant/test_client.py @@ -3,7 +3,8 @@ from unittest.mock import AsyncMock, Mock import pytest -from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient +from qdrant_client import AsyncQdrantClient +from qdrant_client import QdrantClient as SyncQdrantClient from crewai.rag.core.exceptions import ClientMethodMismatchError from crewai.rag.qdrant.client import QdrantClient @@ -435,7 +436,7 @@ class TestQdrantClient: call_args = mock_qdrant_client.query_points.call_args assert call_args.kwargs["collection_name"] == "test_collection" assert call_args.kwargs["query"] == [0.1, 0.2, 0.3] - assert call_args.kwargs["limit"] == 10 + assert call_args.kwargs["limit"] == 5 assert call_args.kwargs["with_payload"] is True assert call_args.kwargs["with_vectors"] is False @@ -540,7 +541,7 @@ class TestQdrantClient: call_args = mock_async_qdrant_client.query_points.call_args assert call_args.kwargs["collection_name"] == "test_collection" assert call_args.kwargs["query"] == [0.1, 0.2, 0.3] - assert call_args.kwargs["limit"] == 10 + assert call_args.kwargs["limit"] == 5 assert call_args.kwargs["with_payload"] is True assert call_args.kwargs["with_vectors"] is False