diff --git a/lib/crewai/src/crewai/knowledge/knowledge.py b/lib/crewai/src/crewai/knowledge/knowledge.py index cb53ab3d6..eceef8b99 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge.py +++ b/lib/crewai/src/crewai/knowledge/knowledge.py @@ -32,8 +32,8 @@ class Knowledge(BaseModel): sources: list[BaseKnowledgeSource], embedder: EmbedderConfig | None = None, storage: KnowledgeStorage | None = None, - **data, - ): + **data: object, + ) -> None: super().__init__(**data) if storage: self.storage = storage @@ -75,3 +75,44 @@ class Knowledge(BaseModel): self.storage.reset() else: raise ValueError("Storage is not initialized.") + + async def aquery( + self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6 + ) -> list[SearchResult]: + """Query across all knowledge sources asynchronously. + + Args: + query: List of query strings. + results_limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + The top results matching the query. + + Raises: + ValueError: If storage is not initialized. + """ + if self.storage is None: + raise ValueError("Storage is not initialized.") + + return await self.storage.asearch( + query, + limit=results_limit, + score_threshold=score_threshold, + ) + + async def aadd_sources(self) -> None: + """Add all knowledge sources to storage asynchronously.""" + try: + for source in self.sources: + source.storage = self.storage + await source.aadd() + except Exception as e: + raise e + + async def areset(self) -> None: + """Reset the knowledge base asynchronously.""" + if self.storage: + await self.storage.areset() + else: + raise ValueError("Storage is not initialized.") diff --git a/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py index 42af18736..0832717c1 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path +from typing import Any from pydantic import Field, field_validator @@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): safe_file_paths: list[Path] = Field(default_factory=list) @field_validator("file_path", "file_paths", mode="before") - def validate_file_path(cls, v, info): # noqa: N805 + @classmethod + def validate_file_path( + cls, v: Path | list[Path] | str | list[str] | None, info: Any + ) -> Path | list[Path] | str | list[str] | None: """Validate that at least one of file_path or file_paths is provided.""" # Single check if both are None, O(1) instead of nested conditions if ( @@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): raise ValueError("Either file_path or file_paths must be provided") return v - def model_post_init(self, _): + def model_post_init(self, _: Any) -> None: """Post-initialization method to load content.""" self.safe_file_paths = self._process_file_paths() self.validate_content() @@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): def load_content(self) -> dict[Path, str]: """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory.""" - def validate_content(self): + def validate_content(self) -> None: """Validate the paths.""" for path in self.safe_file_paths: if not path.exists(): @@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): color="red", ) - def _save_documents(self): + def _save_documents(self) -> None: """Save the documents to the storage.""" if self.storage: self.storage.save(self.chunks) else: raise ValueError("No storage found to save documents.") + async def _asave_documents(self) -> None: + """Save the documents to the storage asynchronously.""" + if self.storage: + await self.storage.asave(self.chunks) + else: + raise ValueError("No storage found to save documents.") + def convert_to_path(self, path: Path | str) -> Path: """Convert a path to a Path object.""" return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path diff --git a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py index b62dd0f04..34774ce82 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py @@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC): for i in range(0, len(text), self.chunk_size - self.chunk_overlap) ] - def _save_documents(self): - """ - Save the documents to the storage. + def _save_documents(self) -> None: + """Save the documents to the storage. + This method should be called after the chunks and embeddings are generated. + + Raises: + ValueError: If no storage is configured. """ if self.storage: self.storage.save(self.chunks) else: raise ValueError("No storage found to save documents.") + + @abstractmethod + async def aadd(self) -> None: + """Process content, chunk it, compute embeddings, and save them asynchronously.""" + + async def _asave_documents(self) -> None: + """Save the documents to the storage asynchronously. + + This method should be called after the chunks and embeddings are generated. + + Raises: + ValueError: If no storage is configured. + """ + if self.storage: + await self.storage.asave(self.chunks) + else: + raise ValueError("No storage found to save documents.") diff --git a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py index 9061fe3fd..3dddacfac 100644 --- a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py +++ b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py @@ -2,27 +2,24 @@ from __future__ import annotations from collections.abc import Iterator from pathlib import Path +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse try: - from docling.datamodel.base_models import ( # type: ignore[import-not-found] - InputFormat, - ) - from docling.document_converter import ( # type: ignore[import-not-found] - DocumentConverter, - ) - from docling.exceptions import ConversionError # type: ignore[import-not-found] - from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found] - HierarchicalChunker, - ) - from docling_core.types.doc.document import ( # type: ignore[import-not-found] - DoclingDocument, - ) + from docling.datamodel.base_models import InputFormat + from docling.document_converter import DocumentConverter + from docling.exceptions import ConversionError + from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker + from docling_core.types.doc.document import DoclingDocument DOCLING_AVAILABLE = True except ImportError: DOCLING_AVAILABLE = False + # Provide type stubs for when docling is not available + if TYPE_CHECKING: + from docling.document_converter import DocumentConverter + from docling_core.types.doc.document import DoclingDocument from pydantic import Field @@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger class CrewDoclingSource(BaseKnowledgeSource): - """Default Source class for converting documents to markdown or json - This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth. + """Default Source class for converting documents to markdown or json. + + This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without + any additional dependencies and follows the docling package as the source of truth. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: if not DOCLING_AVAILABLE: raise ImportError( "The docling package is required to use CrewDoclingSource. " @@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource): ) ) - def model_post_init(self, _) -> None: + def model_post_init(self, _: Any) -> None: if self.file_path: self._logger.log( "warning", @@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource): self.chunks.extend(list(new_chunks_iterable)) self._save_documents() + async def aadd(self) -> None: + """Add docling content asynchronously.""" + if self.content is None: + return + for doc in self.content: + new_chunks_iterable = self._chunk_doc(doc) + self.chunks.extend(list(new_chunks_iterable)) + await self._asave_documents() + def _convert_source_to_docling_documents(self) -> list[DoclingDocument]: conv_results_iter = self.document_converter.convert_all(self.safe_file_paths) return [result.document for result in conv_results_iter] diff --git a/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py index dc7401598..7da82c3e3 100644 --- a/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py @@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add CSV file content asynchronously.""" + content_str = ( + str(self.content) if isinstance(self.content, dict) else self.content + ) + new_chunks = self._chunk_text(content_str) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py index 3c33e8803..ece582053 100644 --- a/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py @@ -1,4 +1,6 @@ from pathlib import Path +from types import ModuleType +from typing import Any from pydantic import Field, field_validator @@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): safe_file_paths: list[Path] = Field(default_factory=list) @field_validator("file_path", "file_paths", mode="before") - def validate_file_path(cls, v, info): # noqa: N805 + @classmethod + def validate_file_path( + cls, v: Path | list[Path] | str | list[str] | None, info: Any + ) -> Path | list[Path] | str | list[str] | None: """Validate that at least one of file_path or file_paths is provided.""" # Single check if both are None, O(1) instead of nested conditions if ( @@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): return [self.convert_to_path(path) for path in path_list] - def validate_content(self): + def validate_content(self) -> None: """Validate the paths.""" for path in self.safe_file_paths: if not path.exists(): @@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): color="red", ) - def model_post_init(self, _) -> None: + def model_post_init(self, _: Any) -> None: if self.file_path: self._logger.log( "warning", @@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): """Convert a path to a Path object.""" return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path - def _import_dependencies(self): + def _import_dependencies(self) -> ModuleType: """Dynamically import dependencies.""" try: - import pandas as pd # type: ignore[import-untyped,import-not-found] + import pandas as pd # type: ignore[import-untyped] - return pd + return pd # type: ignore[no-any-return] except ImportError as e: missing_package = str(e).split()[-1] raise ImportError( @@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add Excel file content asynchronously.""" + content_str = "" + for value in self.content.values(): + if isinstance(value, dict): + for sheet_value in value.values(): + content_str += str(sheet_value) + "\n" + else: + content_str += str(value) + "\n" + + new_chunks = self._chunk_text(content_str) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py index 0e5c847e2..ac527af2d 100644 --- a/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py @@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add JSON file content asynchronously.""" + content_str = ( + str(self.content) if isinstance(self.content, dict) else self.content + ) + new_chunks = self._chunk_text(content_str) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py index 7fa663b92..8af860875 100644 --- a/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py @@ -1,4 +1,5 @@ from pathlib import Path +from types import ModuleType from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource): content[path] = text return content - def _import_pdfplumber(self): + def _import_pdfplumber(self) -> ModuleType: """Dynamically import pdfplumber.""" try: import pdfplumber @@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add PDF file content asynchronously.""" + for text in self.content.values(): + new_chunks = self._chunk_text(text) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py index 97473d9d3..b1165c2d1 100644 --- a/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import Field from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource @@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource): content: str = Field(...) collection_name: str | None = Field(default=None) - def model_post_init(self, _): + def model_post_init(self, _: Any) -> None: """Post-initialization method to validate content.""" self.validate_content() - def validate_content(self): + def validate_content(self) -> None: """Validate string content.""" if not isinstance(self.content, str): raise ValueError("StringKnowledgeSource only accepts string content") @@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add string content asynchronously.""" + new_chunks = self._chunk_text(self.content) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py index 93a3e2849..00265743d 100644 --- a/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py @@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() + async def aadd(self) -> None: + """Add text file content asynchronously.""" + for text in self.content.values(): + new_chunks = self._chunk_text(text) + self.chunks.extend(new_chunks) + await self._asave_documents() + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ diff --git a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py index 044837a07..e8a2054f7 100644 --- a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC): ) -> list[SearchResult]: """Search for documents in the knowledge base.""" + @abstractmethod + async def asearch( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + """Search for documents in the knowledge base asynchronously.""" + @abstractmethod def save(self, documents: list[str]) -> None: """Save documents to the knowledge base.""" + @abstractmethod + async def asave(self, documents: list[str]) -> None: + """Save documents to the knowledge base asynchronously.""" + @abstractmethod def reset(self) -> None: """Reset the knowledge base.""" + + @abstractmethod + async def areset(self) -> None: + """Reset the knowledge base asynchronously.""" diff --git a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py index 7eed0e0de..055763f7f 100644 --- a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py @@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage): def __init__( self, embedder: ProviderSpec - | BaseEmbeddingsProvider - | type[BaseEmbeddingsProvider] + | BaseEmbeddingsProvider[Any] + | type[BaseEmbeddingsProvider[Any]] | None = None, collection_name: str | None = None, ) -> None: @@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) from e Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") raise + + async def asearch( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + """Search for documents in the knowledge base asynchronously. + + Args: + query: List of query strings. + limit: Maximum number of results to return. + metadata_filter: Optional metadata filter for the search. + score_threshold: Minimum similarity score for results. + + Returns: + List of search results. + """ + try: + if not query: + raise ValueError("Query cannot be empty") + + client = self._get_client() + collection_name = ( + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" + ) + query_text = " ".join(query) if len(query) > 1 else query[0] + + return await client.asearch( + collection_name=collection_name, + query=query_text, + limit=limit, + metadata_filter=metadata_filter, + score_threshold=score_threshold, + ) + except Exception as e: + logging.error( + f"Error during knowledge search: {e!s}\n{traceback.format_exc()}" + ) + return [] + + async def asave(self, documents: list[str]) -> None: + """Save documents to the knowledge base asynchronously. + + Args: + documents: List of document strings to save. + """ + try: + client = self._get_client() + collection_name = ( + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" + ) + await client.aget_or_create_collection(collection_name=collection_name) + + rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] + + await client.aadd_documents( + collection_name=collection_name, documents=rag_documents + ) + except Exception as e: + if "dimension mismatch" in str(e).lower(): + Logger(verbose=True).log( + "error", + "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`", + "red", + ) + raise ValueError( + "Embedding dimension mismatch. Make sure you're using the same embedding model " + "across all operations with this collection." + "Try resetting the collection using `crewai reset-memories -a`" + ) from e + Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") + raise + + async def areset(self) -> None: + """Reset the knowledge base asynchronously.""" + try: + client = self._get_client() + collection_name = ( + f"knowledge_{self.collection_name}" + if self.collection_name + else "knowledge" + ) + await client.adelete_collection(collection_name=collection_name) + except Exception as e: + logging.error( + f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}" + ) diff --git a/lib/crewai/tests/knowledge/test_async_knowledge.py b/lib/crewai/tests/knowledge/test_async_knowledge.py new file mode 100644 index 000000000..c243b3ce4 --- /dev/null +++ b/lib/crewai/tests/knowledge/test_async_knowledge.py @@ -0,0 +1,212 @@ +"""Tests for async knowledge operations.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource +from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage + + +class TestAsyncKnowledgeStorage: + """Tests for async KnowledgeStorage operations.""" + + @pytest.mark.asyncio + async def test_asearch_returns_results(self): + """Test that asearch returns search results.""" + mock_client = MagicMock() + mock_client.asearch = AsyncMock( + return_value=[{"content": "test result", "score": 0.9}] + ) + + storage = KnowledgeStorage(collection_name="test_collection") + storage._client = mock_client + + results = await storage.asearch(["test query"]) + + assert len(results) == 1 + assert results[0]["content"] == "test result" + mock_client.asearch.assert_called_once() + + @pytest.mark.asyncio + async def test_asearch_empty_query_raises_error(self): + """Test that asearch handles empty query.""" + storage = KnowledgeStorage(collection_name="test_collection") + + # Empty query should not raise but return empty results due to error handling + results = await storage.asearch([]) + assert results == [] + + @pytest.mark.asyncio + async def test_asave_calls_client_methods(self): + """Test that asave calls the correct client methods.""" + mock_client = MagicMock() + mock_client.aget_or_create_collection = AsyncMock() + mock_client.aadd_documents = AsyncMock() + + storage = KnowledgeStorage(collection_name="test_collection") + storage._client = mock_client + + await storage.asave(["document 1", "document 2"]) + + mock_client.aget_or_create_collection.assert_called_once_with( + collection_name="knowledge_test_collection" + ) + mock_client.aadd_documents.assert_called_once() + + @pytest.mark.asyncio + async def test_areset_calls_client_delete(self): + """Test that areset calls delete_collection on the client.""" + mock_client = MagicMock() + mock_client.adelete_collection = AsyncMock() + + storage = KnowledgeStorage(collection_name="test_collection") + storage._client = mock_client + + await storage.areset() + + mock_client.adelete_collection.assert_called_once_with( + collection_name="knowledge_test_collection" + ) + + +class TestAsyncKnowledge: + """Tests for async Knowledge operations.""" + + @pytest.mark.asyncio + async def test_aquery_calls_storage_asearch(self): + """Test that aquery calls storage.asearch.""" + mock_storage = MagicMock(spec=KnowledgeStorage) + mock_storage.asearch = AsyncMock( + return_value=[{"content": "result", "score": 0.8}] + ) + + knowledge = Knowledge( + collection_name="test", + sources=[], + storage=mock_storage, + ) + + results = await knowledge.aquery(["test query"]) + + assert len(results) == 1 + mock_storage.asearch.assert_called_once_with( + ["test query"], + limit=5, + score_threshold=0.6, + ) + + @pytest.mark.asyncio + async def test_aquery_raises_when_storage_not_initialized(self): + """Test that aquery raises ValueError when storage is None.""" + knowledge = Knowledge( + collection_name="test", + sources=[], + storage=MagicMock(spec=KnowledgeStorage), + ) + knowledge.storage = None + + with pytest.raises(ValueError, match="Storage is not initialized"): + await knowledge.aquery(["test query"]) + + @pytest.mark.asyncio + async def test_aadd_sources_calls_source_aadd(self): + """Test that aadd_sources calls aadd on each source.""" + mock_storage = MagicMock(spec=KnowledgeStorage) + mock_source = MagicMock() + mock_source.aadd = AsyncMock() + + knowledge = Knowledge( + collection_name="test", + sources=[mock_source], + storage=mock_storage, + ) + + await knowledge.aadd_sources() + + mock_source.aadd.assert_called_once() + assert mock_source.storage == mock_storage + + @pytest.mark.asyncio + async def test_areset_calls_storage_areset(self): + """Test that areset calls storage.areset.""" + mock_storage = MagicMock(spec=KnowledgeStorage) + mock_storage.areset = AsyncMock() + + knowledge = Knowledge( + collection_name="test", + sources=[], + storage=mock_storage, + ) + + await knowledge.areset() + + mock_storage.areset.assert_called_once() + + @pytest.mark.asyncio + async def test_areset_raises_when_storage_not_initialized(self): + """Test that areset raises ValueError when storage is None.""" + knowledge = Knowledge( + collection_name="test", + sources=[], + storage=MagicMock(spec=KnowledgeStorage), + ) + knowledge.storage = None + + with pytest.raises(ValueError, match="Storage is not initialized"): + await knowledge.areset() + + +class TestAsyncStringKnowledgeSource: + """Tests for async StringKnowledgeSource operations.""" + + @pytest.mark.asyncio + async def test_aadd_saves_documents_asynchronously(self): + """Test that aadd chunks and saves documents asynchronously.""" + mock_storage = MagicMock(spec=KnowledgeStorage) + mock_storage.asave = AsyncMock() + + source = StringKnowledgeSource(content="Test content for async processing") + source.storage = mock_storage + + await source.aadd() + + mock_storage.asave.assert_called_once() + assert len(source.chunks) > 0 + + @pytest.mark.asyncio + async def test_aadd_raises_without_storage(self): + """Test that aadd raises ValueError when storage is not set.""" + source = StringKnowledgeSource(content="Test content") + source.storage = None + + with pytest.raises(ValueError, match="No storage found"): + await source.aadd() + + +class TestAsyncBaseKnowledgeSource: + """Tests for async _asave_documents method.""" + + @pytest.mark.asyncio + async def test_asave_documents_calls_storage_asave(self): + """Test that _asave_documents calls storage.asave.""" + mock_storage = MagicMock(spec=KnowledgeStorage) + mock_storage.asave = AsyncMock() + + source = StringKnowledgeSource(content="Test") + source.storage = mock_storage + source.chunks = ["chunk1", "chunk2"] + + await source._asave_documents() + + mock_storage.asave.assert_called_once_with(["chunk1", "chunk2"]) + + @pytest.mark.asyncio + async def test_asave_documents_raises_without_storage(self): + """Test that _asave_documents raises ValueError when storage is None.""" + source = StringKnowledgeSource(content="Test") + source.storage = None + + with pytest.raises(ValueError, match="No storage found"): + await source._asave_documents() \ No newline at end of file