mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 07:38:14 +00:00
feat: async knowledge support; add tests
This commit is contained in:
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
|
|||||||
sources: list[BaseKnowledgeSource],
|
sources: list[BaseKnowledgeSource],
|
||||||
embedder: EmbedderConfig | None = None,
|
embedder: EmbedderConfig | None = None,
|
||||||
storage: KnowledgeStorage | None = None,
|
storage: KnowledgeStorage | None = None,
|
||||||
**data,
|
**data: object,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
if storage:
|
if storage:
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
|||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Storage is not initialized.")
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|
||||||
|
async def aquery(
|
||||||
|
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Query across all knowledge sources asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: List of query strings.
|
||||||
|
results_limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The top results matching the query.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If storage is not initialized.
|
||||||
|
"""
|
||||||
|
if self.storage is None:
|
||||||
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|
||||||
|
return await self.storage.asearch(
|
||||||
|
query,
|
||||||
|
limit=results_limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aadd_sources(self) -> None:
|
||||||
|
"""Add all knowledge sources to storage asynchronously."""
|
||||||
|
try:
|
||||||
|
for source in self.sources:
|
||||||
|
source.storage = self.storage
|
||||||
|
await source.aadd()
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.areset()
|
||||||
|
else:
|
||||||
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info): # noqa: N805
|
@classmethod
|
||||||
|
def validate_file_path(
|
||||||
|
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||||
|
) -> Path | list[Path] | str | list[str] | None:
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
raise ValueError("Either file_path or file_paths must be provided")
|
raise ValueError("Either file_path or file_paths must be provided")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _: Any) -> None:
|
||||||
"""Post-initialization method to load content."""
|
"""Post-initialization method to load content."""
|
||||||
self.safe_file_paths = self._process_file_paths()
|
self.safe_file_paths = self._process_file_paths()
|
||||||
self.validate_content()
|
self.validate_content()
|
||||||
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
def load_content(self) -> dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _save_documents(self):
|
def _save_documents(self) -> None:
|
||||||
"""Save the documents to the storage."""
|
"""Save the documents to the storage."""
|
||||||
if self.storage:
|
if self.storage:
|
||||||
self.storage.save(self.chunks)
|
self.storage.save(self.chunks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
|
async def _asave_documents(self) -> None:
|
||||||
|
"""Save the documents to the storage asynchronously."""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.asave(self.chunks)
|
||||||
|
else:
|
||||||
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
def convert_to_path(self, path: Path | str) -> Path:
|
def convert_to_path(self, path: Path | str) -> Path:
|
||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|||||||
@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _save_documents(self):
|
def _save_documents(self) -> None:
|
||||||
"""
|
"""Save the documents to the storage.
|
||||||
Save the documents to the storage.
|
|
||||||
This method should be called after the chunks and embeddings are generated.
|
This method should be called after the chunks and embeddings are generated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no storage is configured.
|
||||||
"""
|
"""
|
||||||
if self.storage:
|
if self.storage:
|
||||||
self.storage.save(self.chunks)
|
self.storage.save(self.chunks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
|
||||||
|
|
||||||
|
async def _asave_documents(self) -> None:
|
||||||
|
"""Save the documents to the storage asynchronously.
|
||||||
|
|
||||||
|
This method should be called after the chunks and embeddings are generated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no storage is configured.
|
||||||
|
"""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.asave(self.chunks)
|
||||||
|
else:
|
||||||
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|||||||
@@ -2,27 +2,24 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
from docling.datamodel.base_models import InputFormat
|
||||||
InputFormat,
|
from docling.document_converter import DocumentConverter
|
||||||
)
|
from docling.exceptions import ConversionError
|
||||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||||
DocumentConverter,
|
from docling_core.types.doc.document import DoclingDocument
|
||||||
)
|
|
||||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
|
||||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
|
||||||
HierarchicalChunker,
|
|
||||||
)
|
|
||||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
|
||||||
DoclingDocument,
|
|
||||||
)
|
|
||||||
|
|
||||||
DOCLING_AVAILABLE = True
|
DOCLING_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
DOCLING_AVAILABLE = False
|
DOCLING_AVAILABLE = False
|
||||||
|
# Provide type stubs for when docling is not available
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from docling.document_converter import DocumentConverter
|
||||||
|
from docling_core.types.doc.document import DoclingDocument
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
|
|||||||
|
|
||||||
|
|
||||||
class CrewDoclingSource(BaseKnowledgeSource):
|
class CrewDoclingSource(BaseKnowledgeSource):
|
||||||
"""Default Source class for converting documents to markdown or json
|
"""Default Source class for converting documents to markdown or json.
|
||||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
|
||||||
|
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
|
||||||
|
any additional dependencies and follows the docling package as the source of truth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
if not DOCLING_AVAILABLE:
|
if not DOCLING_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The docling package is required to use CrewDoclingSource. "
|
"The docling package is required to use CrewDoclingSource. "
|
||||||
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_post_init(self, _) -> None:
|
def model_post_init(self, _: Any) -> None:
|
||||||
if self.file_path:
|
if self.file_path:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
"warning",
|
"warning",
|
||||||
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(list(new_chunks_iterable))
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add docling content asynchronously."""
|
||||||
|
if self.content is None:
|
||||||
|
return
|
||||||
|
for doc in self.content:
|
||||||
|
new_chunks_iterable = self._chunk_doc(doc)
|
||||||
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
||||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||||
return [result.document for result in conv_results_iter]
|
return [result.document for result in conv_results_iter]
|
||||||
|
|||||||
@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add CSV file content asynchronously."""
|
||||||
|
content_str = (
|
||||||
|
str(self.content) if isinstance(self.content, dict) else self.content
|
||||||
|
)
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info): # noqa: N805
|
@classmethod
|
||||||
|
def validate_file_path(
|
||||||
|
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||||
|
) -> Path | list[Path] | str | list[str] | None:
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
|
|
||||||
return [self.convert_to_path(path) for path in path_list]
|
return [self.convert_to_path(path) for path in path_list]
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_post_init(self, _) -> None:
|
def model_post_init(self, _: Any) -> None:
|
||||||
if self.file_path:
|
if self.file_path:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
"warning",
|
"warning",
|
||||||
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|
||||||
def _import_dependencies(self):
|
def _import_dependencies(self) -> ModuleType:
|
||||||
"""Dynamically import dependencies."""
|
"""Dynamically import dependencies."""
|
||||||
try:
|
try:
|
||||||
import pandas as pd # type: ignore[import-untyped,import-not-found]
|
import pandas as pd # type: ignore[import-untyped]
|
||||||
|
|
||||||
return pd
|
return pd # type: ignore[no-any-return]
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
missing_package = str(e).split()[-1]
|
missing_package = str(e).split()[-1]
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add Excel file content asynchronously."""
|
||||||
|
content_str = ""
|
||||||
|
for value in self.content.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for sheet_value in value.values():
|
||||||
|
content_str += str(sheet_value) + "\n"
|
||||||
|
else:
|
||||||
|
content_str += str(value) + "\n"
|
||||||
|
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add JSON file content asynchronously."""
|
||||||
|
content_str = (
|
||||||
|
str(self.content) if isinstance(self.content, dict) else self.content
|
||||||
|
)
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
content[path] = text
|
content[path] = text
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def _import_pdfplumber(self):
|
def _import_pdfplumber(self) -> ModuleType:
|
||||||
"""Dynamically import pdfplumber."""
|
"""Dynamically import pdfplumber."""
|
||||||
try:
|
try:
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add PDF file content asynchronously."""
|
||||||
|
for text in self.content.values():
|
||||||
|
new_chunks = self._chunk_text(text)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
content: str = Field(...)
|
content: str = Field(...)
|
||||||
collection_name: str | None = Field(default=None)
|
collection_name: str | None = Field(default=None)
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _: Any) -> None:
|
||||||
"""Post-initialization method to validate content."""
|
"""Post-initialization method to validate content."""
|
||||||
self.validate_content()
|
self.validate_content()
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate string content."""
|
"""Validate string content."""
|
||||||
if not isinstance(self.content, str):
|
if not isinstance(self.content, str):
|
||||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||||
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add string content asynchronously."""
|
||||||
|
new_chunks = self._chunk_text(self.content)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add text file content asynchronously."""
|
||||||
|
for text in self.content.values():
|
||||||
|
new_chunks = self._chunk_text(text)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
|
|||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""Search for documents in the knowledge base."""
|
"""Search for documents in the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: list[str],
|
||||||
|
limit: int = 5,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Search for documents in the knowledge base asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, documents: list[str]) -> None:
|
def save(self, documents: list[str]) -> None:
|
||||||
"""Save documents to the knowledge base."""
|
"""Save documents to the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def asave(self, documents: list[str]) -> None:
|
||||||
|
"""Save documents to the knowledge base asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset the knowledge base."""
|
"""Reset the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedder: ProviderSpec
|
embedder: ProviderSpec
|
||||||
| BaseEmbeddingsProvider
|
| BaseEmbeddingsProvider[Any]
|
||||||
| type[BaseEmbeddingsProvider]
|
| type[BaseEmbeddingsProvider[Any]]
|
||||||
| None = None,
|
| None = None,
|
||||||
collection_name: str | None = None,
|
collection_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
) from e
|
) from e
|
||||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: list[str],
|
||||||
|
limit: int = 5,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Search for documents in the knowledge base asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: List of query strings.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
metadata_filter: Optional metadata filter for the search.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not query:
|
||||||
|
raise ValueError("Query cannot be empty")
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
query_text = " ".join(query) if len(query) > 1 else query[0]
|
||||||
|
|
||||||
|
return await client.asearch(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query_text,
|
||||||
|
limit=limit,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def asave(self, documents: list[str]) -> None:
|
||||||
|
"""Save documents to the knowledge base asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: List of document strings to save.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
await client.aget_or_create_collection(collection_name=collection_name)
|
||||||
|
|
||||||
|
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||||
|
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name, documents=rag_documents
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "dimension mismatch" in str(e).lower():
|
||||||
|
Logger(verbose=True).log(
|
||||||
|
"error",
|
||||||
|
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||||
|
"across all operations with this collection."
|
||||||
|
"Try resetting the collection using `crewai reset-memories -a`"
|
||||||
|
) from e
|
||||||
|
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
await client.adelete_collection(collection_name=collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|||||||
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Tests for async knowledge operations."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
|
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||||
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncKnowledgeStorage:
|
||||||
|
"""Tests for async KnowledgeStorage operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_returns_results(self):
|
||||||
|
"""Test that asearch returns search results."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.asearch = AsyncMock(
|
||||||
|
return_value=[{"content": "test result", "score": 0.9}]
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
results = await storage.asearch(["test query"])
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["content"] == "test result"
|
||||||
|
mock_client.asearch.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_empty_query_raises_error(self):
|
||||||
|
"""Test that asearch handles empty query."""
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
|
||||||
|
# Empty query should not raise but return empty results due to error handling
|
||||||
|
results = await storage.asearch([])
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_calls_client_methods(self):
|
||||||
|
"""Test that asave calls the correct client methods."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.aget_or_create_collection = AsyncMock()
|
||||||
|
mock_client.aadd_documents = AsyncMock()
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
await storage.asave(["document 1", "document 2"])
|
||||||
|
|
||||||
|
mock_client.aget_or_create_collection.assert_called_once_with(
|
||||||
|
collection_name="knowledge_test_collection"
|
||||||
|
)
|
||||||
|
mock_client.aadd_documents.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_areset_calls_client_delete(self):
|
||||||
|
"""Test that areset calls delete_collection on the client."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.adelete_collection = AsyncMock()
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
await storage.areset()
|
||||||
|
|
||||||
|
mock_client.adelete_collection.assert_called_once_with(
|
||||||
|
collection_name="knowledge_test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncKnowledge:
    """Test suite exercising the asynchronous Knowledge interface."""

    @pytest.mark.asyncio
    async def test_aquery_calls_storage_asearch(self):
        """aquery should delegate to the storage backend's asearch with defaults."""
        storage_mock = MagicMock(spec=KnowledgeStorage)
        storage_mock.asearch = AsyncMock(
            return_value=[{"content": "result", "score": 0.8}]
        )

        kb = Knowledge(
            collection_name="test",
            sources=[],
            storage=storage_mock,
        )

        found = await kb.aquery(["test query"])

        assert len(found) == 1
        # Default limit and score threshold are forwarded unchanged.
        storage_mock.asearch.assert_called_once_with(
            ["test query"],
            limit=5,
            score_threshold=0.6,
        )

    @pytest.mark.asyncio
    async def test_aquery_raises_when_storage_not_initialized(self):
        """aquery must fail loudly when no storage backend is attached."""
        kb = Knowledge(
            collection_name="test",
            sources=[],
            storage=MagicMock(spec=KnowledgeStorage),
        )
        kb.storage = None

        with pytest.raises(ValueError, match="Storage is not initialized"):
            await kb.aquery(["test query"])

    @pytest.mark.asyncio
    async def test_aadd_sources_calls_source_aadd(self):
        """aadd_sources should wire storage into each source and invoke aadd."""
        storage_mock = MagicMock(spec=KnowledgeStorage)
        source_mock = MagicMock()
        source_mock.aadd = AsyncMock()

        kb = Knowledge(
            collection_name="test",
            sources=[source_mock],
            storage=storage_mock,
        )

        await kb.aadd_sources()

        source_mock.aadd.assert_called_once()
        assert source_mock.storage == storage_mock

    @pytest.mark.asyncio
    async def test_areset_calls_storage_areset(self):
        """areset should delegate straight to the storage backend."""
        storage_mock = MagicMock(spec=KnowledgeStorage)
        storage_mock.areset = AsyncMock()

        kb = Knowledge(
            collection_name="test",
            sources=[],
            storage=storage_mock,
        )

        await kb.areset()

        storage_mock.areset.assert_called_once()

    @pytest.mark.asyncio
    async def test_areset_raises_when_storage_not_initialized(self):
        """areset must fail loudly when no storage backend is attached."""
        kb = Knowledge(
            collection_name="test",
            sources=[],
            storage=MagicMock(spec=KnowledgeStorage),
        )
        kb.storage = None

        with pytest.raises(ValueError, match="Storage is not initialized"):
            await kb.areset()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncStringKnowledgeSource:
    """Async behaviour of StringKnowledgeSource."""

    @pytest.mark.asyncio
    async def test_aadd_saves_documents_asynchronously(self):
        """aadd should chunk the content and persist it via storage.asave."""
        storage_mock = MagicMock(spec=KnowledgeStorage)
        storage_mock.asave = AsyncMock()

        src = StringKnowledgeSource(content="Test content for async processing")
        src.storage = storage_mock

        await src.aadd()

        storage_mock.asave.assert_called_once()
        # Chunking must have produced at least one chunk before saving.
        assert len(src.chunks) > 0

    @pytest.mark.asyncio
    async def test_aadd_raises_without_storage(self):
        """aadd must raise ValueError when no storage has been attached."""
        src = StringKnowledgeSource(content="Test content")
        src.storage = None

        with pytest.raises(ValueError, match="No storage found"):
            await src.aadd()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncBaseKnowledgeSource:
    """Tests for the async persistence helper (_asave_documents)."""

    @pytest.mark.asyncio
    async def test_asave_documents_calls_storage_asave(self):
        """_asave_documents should forward the source's chunks to storage.asave."""
        storage_mock = MagicMock(spec=KnowledgeStorage)
        storage_mock.asave = AsyncMock()

        src = StringKnowledgeSource(content="Test")
        src.storage = storage_mock
        src.chunks = ["chunk1", "chunk2"]

        await src._asave_documents()

        storage_mock.asave.assert_called_once_with(["chunk1", "chunk2"])

    @pytest.mark.asyncio
    async def test_asave_documents_raises_without_storage(self):
        """_asave_documents must raise ValueError when storage is missing."""
        src = StringKnowledgeSource(content="Test")
        src.storage = None

        with pytest.raises(ValueError, match="No storage found"):
            await src._asave_documents()
|
||||||
Reference in New Issue
Block a user