mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: async knowledge support (#4023)
* feat: add async support for tools, add async tool tests * chore: improve tool decorator typing * fix: ensure _run backward compat * chore: update docs * chore: make docstrings a little more readable * feat: add async execution support to agent executor * chore: add tests * feat: add aiosqlite dep; regenerate lockfile * feat: add async ops to memory feat; create tests * feat: async knowledge support; add tests * chore: regenerate lockfile
This commit is contained in:
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
**data: object,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""Query across all knowledge sources asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
results_limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
The top results matching the query.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
return await self.storage.asearch(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def aadd_sources(self) -> None:
|
||||
"""Add all knowledge sources to storage asynchronously."""
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
await source.aadd()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@classmethod
|
||||
def validate_file_path(
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
return v
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
"""Post-initialization method to load content."""
|
||||
self.safe_file_paths = self._process_file_paths()
|
||||
self.validate_content()
|
||||
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def _save_documents(self):
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||
]
|
||||
|
||||
def _save_documents(self):
|
||||
"""
|
||||
Save the documents to the storage.
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@abstractmethod
|
||||
async def aadd(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@@ -2,27 +2,24 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
||||
InputFormat,
|
||||
)
|
||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
||||
DocumentConverter,
|
||||
)
|
||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
||||
HierarchicalChunker,
|
||||
)
|
||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
||||
DoclingDocument,
|
||||
)
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
# Provide type stubs for when docling is not available
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
||||
"""Default Source class for converting documents to markdown or json.
|
||||
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
|
||||
any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
)
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add docling content asynchronously."""
|
||||
if self.content is None:
|
||||
return
|
||||
for doc in self.content:
|
||||
new_chunks_iterable = self._chunk_doc(doc)
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
await self._asave_documents()
|
||||
|
||||
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add CSV file content asynchronously."""
|
||||
content_str = (
|
||||
str(self.content) if isinstance(self.content, dict) else self.content
|
||||
)
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@classmethod
|
||||
def validate_file_path(
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
def _import_dependencies(self):
|
||||
def _import_dependencies(self) -> ModuleType:
|
||||
"""Dynamically import dependencies."""
|
||||
try:
|
||||
import pandas as pd # type: ignore[import-untyped,import-not-found]
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
return pd
|
||||
return pd # type: ignore[no-any-return]
|
||||
except ImportError as e:
|
||||
missing_package = str(e).split()[-1]
|
||||
raise ImportError(
|
||||
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add Excel file content asynchronously."""
|
||||
content_str = ""
|
||||
for value in self.content.values():
|
||||
if isinstance(value, dict):
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
else:
|
||||
content_str += str(value) + "\n"
|
||||
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add JSON file content asynchronously."""
|
||||
content_str = (
|
||||
str(self.content) if isinstance(self.content, dict) else self.content
|
||||
)
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
content[path] = text
|
||||
return content
|
||||
|
||||
def _import_pdfplumber(self):
|
||||
def _import_pdfplumber(self) -> ModuleType:
|
||||
"""Dynamically import pdfplumber."""
|
||||
try:
|
||||
import pdfplumber
|
||||
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add PDF file content asynchronously."""
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
content: str = Field(...)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
"""Post-initialization method to validate content."""
|
||||
self.validate_content()
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate string content."""
|
||||
if not isinstance(self.content, str):
|
||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add string content asynchronously."""
|
||||
new_chunks = self._chunk_text(self.content)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add text file content asynchronously."""
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
|
||||
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
| BaseEmbeddingsProvider
|
||||
| type[BaseEmbeddingsProvider]
|
||||
| BaseEmbeddingsProvider[Any]
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
) from e
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
limit: Maximum number of results to return.
|
||||
metadata_filter: Optional metadata filter for the search.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of search results.
|
||||
"""
|
||||
try:
|
||||
if not query:
|
||||
raise ValueError("Query cannot be empty")
|
||||
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
query_text = " ".join(query) if len(query) > 1 else query[0]
|
||||
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query_text,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base asynchronously.
|
||||
|
||||
Args:
|
||||
documents: List of document strings to save.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||
"red",
|
||||
)
|
||||
raise ValueError(
|
||||
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||
"across all operations with this collection."
|
||||
"Try resetting the collection using `crewai reset-memories -a`"
|
||||
) from e
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
await client.adelete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Tests for async knowledge operations."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
|
||||
class TestAsyncKnowledgeStorage:
|
||||
"""Tests for async KnowledgeStorage operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_returns_results(self):
|
||||
"""Test that asearch returns search results."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.asearch = AsyncMock(
|
||||
return_value=[{"content": "test result", "score": 0.9}]
|
||||
)
|
||||
|
||||
storage = KnowledgeStorage(collection_name="test_collection")
|
||||
storage._client = mock_client
|
||||
|
||||
results = await storage.asearch(["test query"])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "test result"
|
||||
mock_client.asearch.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_empty_query_raises_error(self):
|
||||
"""Test that asearch handles empty query."""
|
||||
storage = KnowledgeStorage(collection_name="test_collection")
|
||||
|
||||
# Empty query should not raise but return empty results due to error handling
|
||||
results = await storage.asearch([])
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_calls_client_methods(self):
|
||||
"""Test that asave calls the correct client methods."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.aget_or_create_collection = AsyncMock()
|
||||
mock_client.aadd_documents = AsyncMock()
|
||||
|
||||
storage = KnowledgeStorage(collection_name="test_collection")
|
||||
storage._client = mock_client
|
||||
|
||||
await storage.asave(["document 1", "document 2"])
|
||||
|
||||
mock_client.aget_or_create_collection.assert_called_once_with(
|
||||
collection_name="knowledge_test_collection"
|
||||
)
|
||||
mock_client.aadd_documents.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset_calls_client_delete(self):
|
||||
"""Test that areset calls delete_collection on the client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.adelete_collection = AsyncMock()
|
||||
|
||||
storage = KnowledgeStorage(collection_name="test_collection")
|
||||
storage._client = mock_client
|
||||
|
||||
await storage.areset()
|
||||
|
||||
mock_client.adelete_collection.assert_called_once_with(
|
||||
collection_name="knowledge_test_collection"
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncKnowledge:
|
||||
"""Tests for async Knowledge operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aquery_calls_storage_asearch(self):
|
||||
"""Test that aquery calls storage.asearch."""
|
||||
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||
mock_storage.asearch = AsyncMock(
|
||||
return_value=[{"content": "result", "score": 0.8}]
|
||||
)
|
||||
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=mock_storage,
|
||||
)
|
||||
|
||||
results = await knowledge.aquery(["test query"])
|
||||
|
||||
assert len(results) == 1
|
||||
mock_storage.asearch.assert_called_once_with(
|
||||
["test query"],
|
||||
limit=5,
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aquery_raises_when_storage_not_initialized(self):
|
||||
"""Test that aquery raises ValueError when storage is None."""
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=MagicMock(spec=KnowledgeStorage),
|
||||
)
|
||||
knowledge.storage = None
|
||||
|
||||
with pytest.raises(ValueError, match="Storage is not initialized"):
|
||||
await knowledge.aquery(["test query"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_sources_calls_source_aadd(self):
|
||||
"""Test that aadd_sources calls aadd on each source."""
|
||||
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||
mock_source = MagicMock()
|
||||
mock_source.aadd = AsyncMock()
|
||||
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[mock_source],
|
||||
storage=mock_storage,
|
||||
)
|
||||
|
||||
await knowledge.aadd_sources()
|
||||
|
||||
mock_source.aadd.assert_called_once()
|
||||
assert mock_source.storage == mock_storage
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset_calls_storage_areset(self):
|
||||
"""Test that areset calls storage.areset."""
|
||||
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||
mock_storage.areset = AsyncMock()
|
||||
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=mock_storage,
|
||||
)
|
||||
|
||||
await knowledge.areset()
|
||||
|
||||
mock_storage.areset.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset_raises_when_storage_not_initialized(self):
|
||||
"""Test that areset raises ValueError when storage is None."""
|
||||
knowledge = Knowledge(
|
||||
collection_name="test",
|
||||
sources=[],
|
||||
storage=MagicMock(spec=KnowledgeStorage),
|
||||
)
|
||||
knowledge.storage = None
|
||||
|
||||
with pytest.raises(ValueError, match="Storage is not initialized"):
|
||||
await knowledge.areset()
|
||||
|
||||
|
||||
class TestAsyncStringKnowledgeSource:
|
||||
"""Tests for async StringKnowledgeSource operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_saves_documents_asynchronously(self):
|
||||
"""Test that aadd chunks and saves documents asynchronously."""
|
||||
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||
mock_storage.asave = AsyncMock()
|
||||
|
||||
source = StringKnowledgeSource(content="Test content for async processing")
|
||||
source.storage = mock_storage
|
||||
|
||||
await source.aadd()
|
||||
|
||||
mock_storage.asave.assert_called_once()
|
||||
assert len(source.chunks) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_raises_without_storage(self):
|
||||
"""Test that aadd raises ValueError when storage is not set."""
|
||||
source = StringKnowledgeSource(content="Test content")
|
||||
source.storage = None
|
||||
|
||||
with pytest.raises(ValueError, match="No storage found"):
|
||||
await source.aadd()
|
||||
|
||||
|
||||
class TestAsyncBaseKnowledgeSource:
|
||||
"""Tests for async _asave_documents method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_documents_calls_storage_asave(self):
|
||||
"""Test that _asave_documents calls storage.asave."""
|
||||
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||
mock_storage.asave = AsyncMock()
|
||||
|
||||
source = StringKnowledgeSource(content="Test")
|
||||
source.storage = mock_storage
|
||||
source.chunks = ["chunk1", "chunk2"]
|
||||
|
||||
await source._asave_documents()
|
||||
|
||||
mock_storage.asave.assert_called_once_with(["chunk1", "chunk2"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_documents_raises_without_storage(self):
|
||||
"""Test that _asave_documents raises ValueError when storage is None."""
|
||||
source = StringKnowledgeSource(content="Test")
|
||||
source.storage = None
|
||||
|
||||
with pytest.raises(ValueError, match="No storage found"):
|
||||
await source._asave_documents()
|
||||
Reference in New Issue
Block a user