Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 00:28:31 +00:00
Adds RAG feature (#406)
* feat: initialize rag
* refactor: use cosine distance metric for chromadb
* feat: use RecursiveCharacterTextSplitter as the chunking strategy
* feat: support a chunker and loader per data_type
* feat: add JSON loader
* feat: add CSVLoader
* feat: add loader for DOCX files
* feat: add loader for MDX files
* feat: add loader for XML files
* feat: add loader for parsing webpages
* feat: support loading files from an entire directory
* feat: support auto-loading the loaders for additional DataTypes
* feat: add chunkers for specific data types - each chunker uses separators suited to its content type
* feat: prevent document duplication and centralize content management
  - Implement document deduplication logic in RAG
    * Check for existing documents by source reference
    * Compare doc IDs to detect content changes
    * Automatically replace outdated content while preventing duplicates
  - Centralize common functionality for better maintainability
    * Create SourceContent class to handle URLs, files, and text uniformly
    * Extract shared utilities (compute_sha256) to misc.py
    * Standardize doc ID generation across all loaders
  - Improve RAG system architecture
    * All loaders now inherit consistent patterns from a centralized BaseLoader
    * Better separation of concerns with dedicated content management classes
    * Standardized LoaderResult structure across all loader implementations
* chore: split the text loaders file
* test: add missing tests for the RAG loaders
* refactor: quality-of-life improvements
* fix: add missing uv syntax to the DOCXLoader install hint
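As a rough usage sketch of the feature described above (the class names follow the code added in this diff, but the file path "./reports/q3.csv" is hypothetical and the default embedding model assumes an OPENAI_API_KEY is available to litellm):

    from crewai_tools.rag.core import RAG

    rag = RAG(
        collection_name="my_kb",
        persist_directory="./chroma_db",  # omit to use an in-memory ChromaDB client
        embedding_model="text-embedding-3-small",
        top_k=5,
    )

    # Data type, loader, and chunker are auto-detected from the source.
    rag.add("https://example.com/handbook.mdx")
    rag.add("./reports/q3.csv")
    rag.add("Plain text snippets are indexed as-is.")

    print(rag.query("What changed in Q3?"))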
41 src/crewai_tools/adapters/rag_adapter.py Normal file
@@ -0,0 +1,41 @@
from typing import Any, Optional

from crewai_tools.rag.core import RAG
from crewai_tools.tools.rag.rag_tool import Adapter


class RAGAdapter(Adapter):
    def __init__(
        self,
        collection_name: str = "crewai_knowledge_base",
        persist_directory: Optional[str] = None,
        embedding_model: str = "text-embedding-3-small",
        top_k: int = 5,
        embedding_api_key: Optional[str] = None,
        **embedding_kwargs
    ):
        super().__init__()

        # Prepare embedding configuration
        embedding_config = {
            "api_key": embedding_api_key,
            **embedding_kwargs
        }

        self._adapter = RAG(
            collection_name=collection_name,
            persist_directory=persist_directory,
            embedding_model=embedding_model,
            top_k=top_k,
            embedding_config=embedding_config
        )

    def query(self, question: str) -> str:
        return self._adapter.query(question)

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._adapter.add(*args, **kwargs)
8 src/crewai_tools/rag/__init__.py Normal file
@@ -0,0 +1,8 @@
from crewai_tools.rag.core import RAG, EmbeddingService
from crewai_tools.rag.data_types import DataType

__all__ = [
    "RAG",
    "EmbeddingService",
    "DataType",
]
37 src/crewai_tools/rag/base_loader.py Normal file
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field

from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent


class LoaderResult(BaseModel):
    content: str = Field(description="The text content of the source")
    source: str = Field(description="The source of the content", default="unknown")
    metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict)
    doc_id: str = Field(description="The id of the document")


class BaseLoader(ABC):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}

    @abstractmethod
    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        ...

    def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str:
        """
        Generate a unique document id from the source reference and content.

        The id is the SHA-256 hash of the source reference concatenated with the content;
        a missing value is treated as an empty string. Both arguments are optional because
        the TEXT content type has no separate source reference, in which case the content
        alone identifies the document.
        """
        source_ref = source_ref or ""
        content = content or ""

        return compute_sha256(source_ref + content)
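For illustration only, a minimal custom loader that follows the BaseLoader/LoaderResult contract above (the UppercaseLoader class is hypothetical and not part of this diff):

    from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
    from crewai_tools.rag.source_content import SourceContent

    class UppercaseLoader(BaseLoader):
        # Hypothetical example loader: returns the raw text upper-cased.
        def load(self, content: SourceContent, **kwargs) -> LoaderResult:
            text = content.source.upper()
            return LoaderResult(
                content=text,
                source=content.source_ref,
                doc_id=self.generate_doc_id(source_ref=content.source_ref, content=text),
            )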
15 src/crewai_tools/rag/chunkers/__init__.py Normal file
@@ -0,0 +1,15 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker
from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker

__all__ = [
    "BaseChunker",
    "DefaultChunker",
    "TextChunker",
    "DocxChunker",
    "MdxChunker",
    "CsvChunker",
    "JsonChunker",
    "XmlChunker",
]
167 src/crewai_tools/rag/chunkers/base_chunker.py Normal file
@@ -0,0 +1,167 @@
from typing import List, Optional
import re


class RecursiveCharacterTextSplitter:
    """
    A text splitter that recursively splits text based on a hierarchy of separators.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
    ):
        """
        Initialize the RecursiveCharacterTextSplitter.

        Args:
            chunk_size: Maximum size of each chunk
            chunk_overlap: Number of characters to overlap between chunks
            separators: List of separators to use for splitting (in order of preference)
            keep_separator: Whether to keep the separator in the split text
        """
        if chunk_overlap >= chunk_size:
            raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})")

        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._keep_separator = keep_separator

        self._separators = separators or [
            "\n\n",
            "\n",
            " ",
            "",
        ]

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        separator = separators[-1]
        new_separators = []

        for i, sep in enumerate(separators):
            if sep == "":
                separator = sep
                break
            if re.search(re.escape(sep), text):
                separator = sep
                new_separators = separators[i + 1:]
                break

        splits = self._split_text_with_separator(text, separator)

        good_splits = []

        for split in splits:
            if len(split) < self._chunk_size:
                good_splits.append(split)
            else:
                if new_separators:
                    other_info = self._split_text(split, new_separators)
                    good_splits.extend(other_info)
                else:
                    good_splits.extend(self._split_by_characters(split))

        return self._merge_splits(good_splits, separator)

    def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
        if separator == "":
            return list(text)

        if self._keep_separator and separator in text:
            parts = text.split(separator)
            splits = []

            for i, part in enumerate(parts):
                if i == 0:
                    splits.append(part)
                elif i == len(parts) - 1:
                    if part:
                        splits.append(separator + part)
                else:
                    if part:
                        splits.append(separator + part)
                    else:
                        if splits:
                            splits[-1] += separator

            return [s for s in splits if s]
        else:
            return text.split(separator)

    def _split_by_characters(self, text: str) -> List[str]:
        chunks = []
        for i in range(0, len(text), self._chunk_size):
            chunks.append(text[i:i + self._chunk_size])
        return chunks

    def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
        """Merge splits into chunks with proper overlap."""
        docs = []
        current_doc = []
        total = 0

        for split in splits:
            split_len = len(split)

            if total + split_len > self._chunk_size and current_doc:
                if separator == "":
                    doc = "".join(current_doc)
                else:
                    doc = separator.join(current_doc)

                if doc:
                    docs.append(doc)

                # Handle overlap by keeping some of the previous content
                while total > self._chunk_overlap and len(current_doc) > 1:
                    removed = current_doc.pop(0)
                    total -= len(removed)
                    if separator != "":
                        total -= len(separator)

            current_doc.append(split)
            total += split_len
            if separator != "" and len(current_doc) > 1:
                total += len(separator)

        if current_doc:
            if separator == "":
                doc = "".join(current_doc)
            else:
                doc = separator.join(current_doc)

            if doc:
                docs.append(doc)

        return docs


class BaseChunker:
    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
        """
        Initialize the Chunker

        Args:
            chunk_size: Maximum size of each chunk
            chunk_overlap: Number of characters to overlap between chunks
            separators: List of separators to use for splitting
            keep_separator: Whether to keep separators in the chunks
        """
        self._splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
            keep_separator=keep_separator,
        )

    def chunk(self, text: str) -> List[str]:
        if not text or not text.strip():
            return []

        return self._splitter.split_text(text)
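A quick sketch of how the splitter behaves (illustrative values; exact chunk boundaries depend on the overlap and separator handling above):

    from crewai_tools.rag.chunkers.base_chunker import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10)
    text = "First paragraph.\n\nSecond paragraph that is a bit longer than the first one."
    # Splits preferentially at the paragraph break, then at line breaks, spaces,
    # and finally individual characters, keeping chunks near the 50-character budget.
    print(splitter.split_text(text))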
6 src/crewai_tools/rag/chunkers/default_chunker.py Normal file
@@ -0,0 +1,6 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class DefaultChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True):
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
49 src/crewai_tools/rag/chunkers/structured_chunker.py Normal file
@@ -0,0 +1,49 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class CsvChunker(BaseChunker):
    def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\nRow ",  # Row boundaries (from CSVLoader format)
                "\n",      # Line breaks
                " | ",     # Column separators
                ", ",      # Comma separators
                " ",       # Word breaks
                "",        # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class JsonChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n\n",  # Object/array boundaries
                "\n",    # Line breaks
                "},",    # Object endings
                "],",    # Array endings
                ", ",    # Property separators
                ": ",    # Key-value separators
                " ",     # Word breaks
                "",      # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class XmlChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n\n",  # Element boundaries
                "\n",    # Line breaks
                ">",     # Tag endings
                ". ",    # Sentence endings (for text content)
                "! ",    # Exclamation endings
                "? ",    # Question endings
                ", ",    # Comma separators
                " ",     # Word breaks
                "",      # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
59 src/crewai_tools/rag/chunkers/text_chunker.py Normal file
@@ -0,0 +1,59 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class TextChunker(BaseChunker):
    def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (sections)
                "\n\n",    # Paragraph breaks
                "\n",      # Line breaks
                ". ",      # Sentence endings
                "! ",      # Exclamation endings
                "? ",      # Question endings
                "; ",      # Semicolon breaks
                ", ",      # Comma breaks
                " ",       # Word breaks
                "",        # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class DocxChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (major sections)
                "\n\n",    # Paragraph breaks
                "\n",      # Line breaks
                ". ",      # Sentence endings
                "! ",      # Exclamation endings
                "? ",      # Question endings
                "; ",      # Semicolon breaks
                ", ",      # Comma breaks
                " ",       # Word breaks
                "",        # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class MdxChunker(BaseChunker):
    def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n## ",    # H2 headers (major sections)
                "\n### ",   # H3 headers (subsections)
                "\n#### ",  # H4 headers (sub-subsections)
                "\n\n",     # Paragraph breaks
                "\n```",    # Code block boundaries
                "\n",       # Line breaks
                ". ",       # Sentence endings
                "! ",       # Exclamation endings
                "? ",       # Question endings
                "; ",       # Semicolon breaks
                ", ",       # Comma breaks
                " ",        # Word breaks
                "",         # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
20 src/crewai_tools/rag/chunkers/web_chunker.py Normal file
@@ -0,0 +1,20 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class WebsiteChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
        if separators is None:
            separators = [
                "\n\n\n",  # Major section breaks
                "\n\n",    # Paragraph breaks
                "\n",      # Line breaks
                ". ",      # Sentence endings
                "! ",      # Exclamation endings
                "? ",      # Question endings
                "; ",      # Semicolon breaks
                ", ",      # Comma breaks
                " ",       # Word breaks
                "",        # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
232 src/crewai_tools/rag/core.py Normal file
@@ -0,0 +1,232 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr

from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.misc import compute_sha256

logger = logging.getLogger(__name__)


class EmbeddingService:
    def __init__(self, model: str = "text-embedding-3-small", **kwargs):
        self.model = model
        self.kwargs = kwargs

    def embed_text(self, text: str) -> List[float]:
        try:
            response = litellm.embedding(
                model=self.model,
                input=[text],
                **self.kwargs
            )
            return response.data[0]['embedding']
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            raise

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        if not texts:
            return []

        try:
            response = litellm.embedding(
                model=self.model,
                input=texts,
                **self.kwargs
            )
            return [data['embedding'] for data in response.data]
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            raise


class Document(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    data_type: DataType = DataType.TEXT
    source: Optional[str] = None


class RAG(Adapter):
    collection_name: str = "crewai_knowledge_base"
    persist_directory: Optional[str] = None
    embedding_model: str = "text-embedding-3-large"
    summarize: bool = False
    top_k: int = 5
    embedding_config: Dict[str, Any] = Field(default_factory=dict)

    _client: Any = PrivateAttr()
    _collection: Any = PrivateAttr()
    _embedding_service: EmbeddingService = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        try:
            if self.persist_directory:
                self._client = chromadb.PersistentClient(path=self.persist_directory)
            else:
                self._client = chromadb.Client()

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"}
            )

            self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise

        super().model_post_init(__context)

    def add(
        self,
        content: str | Path,
        data_type: Optional[Union[str, DataType]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Any
    ) -> None:
        source_content = SourceContent(content)

        data_type = self._get_data_type(data_type=data_type, content=source_content)

        if not loader:
            loader = data_type.get_loader()

        if not chunker:
            chunker = data_type.get_chunker()

        loader_result = loader.load(source_content)
        doc_id = loader_result.doc_id

        existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
        existing_doc_id = existing_doc and existing_doc['metadatas'][0]['doc_id'] if existing_doc['metadatas'] else None

        if existing_doc_id == doc_id:
            logger.warning(f"Document with source {loader_result.source} already exists")
            return

        # A document with the same source ref exists but its content has changed: delete the old reference
        if existing_doc_id and existing_doc_id != loader_result.doc_id:
            logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
            self._collection.delete(where={"doc_id": existing_doc_id})

        documents = []

        chunks = chunker.chunk(loader_result.content)
        for i, chunk in enumerate(chunks):
            doc_metadata = (metadata or {}).copy()
            doc_metadata['chunk_index'] = i
            documents.append(Document(
                id=compute_sha256(chunk),
                content=chunk,
                metadata=doc_metadata,
                data_type=data_type,
                source=loader_result.source
            ))

        if not documents:
            logger.warning("No documents to add")
            return

        contents = [doc.content for doc in documents]
        try:
            embeddings = self._embedding_service.embed_batch(contents)
        except Exception as e:
            logger.error(f"Failed to generate embeddings: {e}")
            return

        ids = [doc.id for doc in documents]
        metadatas = []

        for doc in documents:
            doc_metadata = doc.metadata.copy()
            doc_metadata.update({
                "data_type": doc.data_type.value,
                "source": doc.source,
                "doc_id": doc_id
            })
            metadatas.append(doc_metadata)

        try:
            self._collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
            )
            logger.info(f"Added {len(documents)} documents to knowledge base")
        except Exception as e:
            logger.error(f"Failed to add documents to ChromaDB: {e}")

    def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
        try:
            question_embedding = self._embedding_service.embed_text(question)

            results = self._collection.query(
                query_embeddings=[question_embedding],
                n_results=self.top_k,
                where=where,
                include=["documents", "metadatas", "distances"]
            )

            if not results or not results.get("documents") or not results["documents"][0]:
                return "No relevant content found."

            documents = results["documents"][0]
            metadatas = results.get("metadatas", [None])[0] or []
            distances = results.get("distances", [None])[0] or []

            # Return sources with relevance scores
            formatted_results = []
            for i, doc in enumerate(documents):
                metadata = metadatas[i] if i < len(metadatas) else {}
                distance = distances[i] if i < len(distances) else 1.0
                source = metadata.get("source", "unknown") if metadata else "unknown"
                score = 1 - distance if distance is not None else 0  # Convert distance to similarity
                formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")

            return "\n\n".join(formatted_results)
        except Exception as e:
            logger.error(f"Query failed: {e}")
            return f"Error querying knowledge base: {e}"

    def delete_collection(self) -> None:
        try:
            self._client.delete_collection(self.collection_name)
            logger.info(f"Deleted collection: {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")

    def get_collection_info(self) -> Dict[str, Any]:
        try:
            count = self._collection.count()
            return {
                "name": self.collection_name,
                "count": count,
                "embedding_model": self.embedding_model
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {"error": str(e)}

    def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
        try:
            if isinstance(data_type, str):
                return DataType(data_type)
        except Exception:
            pass

        return content.data_type
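A sketch of the deduplication behavior described in the commit message (the "./notes/meeting.txt" path is hypothetical, and the default embedding model assumes credentials are available to litellm):

    from crewai_tools.rag.core import RAG

    rag = RAG(collection_name="dedup_demo")

    rag.add("./notes/meeting.txt")  # loaded, chunked, embedded, and stored
    rag.add("./notes/meeting.txt")  # same source and same doc_id -> skipped with a warning

    # If meeting.txt changes on disk, its doc_id (SHA-256 of source ref + content) changes:
    # the chunks stored under the old doc_id are deleted and the new content is indexed.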
137 src/crewai_tools/rag/data_types.py Normal file
@@ -0,0 +1,137 @@
from enum import Enum
from pathlib import Path
from urllib.parse import urlparse
import os
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.base_loader import BaseLoader


class DataType(str, Enum):
    PDF_FILE = "pdf_file"
    TEXT_FILE = "text_file"
    CSV = "csv"
    JSON = "json"
    XML = "xml"
    DOCX = "docx"
    MDX = "mdx"

    # Database types
    MYSQL = "mysql"
    POSTGRES = "postgres"

    # Repository types
    GITHUB = "github"
    DIRECTORY = "directory"

    # Web types
    WEBSITE = "website"
    DOCS_SITE = "docs_site"

    # Raw types
    TEXT = "text"

    def get_chunker(self) -> BaseChunker:
        from importlib import import_module

        chunkers = {
            DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT: ("text_chunker", "TextChunker"),
            DataType.DOCX: ("text_chunker", "DocxChunker"),
            DataType.MDX: ("text_chunker", "MdxChunker"),

            # Structured formats
            DataType.CSV: ("structured_chunker", "CsvChunker"),
            DataType.JSON: ("structured_chunker", "JsonChunker"),
            DataType.XML: ("structured_chunker", "XmlChunker"),

            DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
        }

        module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
        module_path = f"crewai_tools.rag.chunkers.{module_name}"

        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
        except Exception as e:
            raise ValueError(f"Error loading chunker for {self}: {e}")

    def get_loader(self) -> BaseLoader:
        from importlib import import_module

        loaders = {
            DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
            DataType.TEXT: ("text_loader", "TextLoader"),
            DataType.XML: ("xml_loader", "XMLLoader"),
            DataType.WEBSITE: ("webpage_loader", "WebPageLoader"),
            DataType.MDX: ("mdx_loader", "MDXLoader"),
            DataType.JSON: ("json_loader", "JSONLoader"),
            DataType.DOCX: ("docx_loader", "DOCXLoader"),
            DataType.CSV: ("csv_loader", "CSVLoader"),
            DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
        }

        module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
        module_path = f"crewai_tools.rag.loaders.{module_name}"
        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
        except Exception as e:
            raise ValueError(f"Error loading loader for {self}: {e}")


class DataTypes:
    @staticmethod
    def from_content(content: str | Path | None = None) -> DataType:
        if content is None:
            return DataType.TEXT

        if isinstance(content, Path):
            content = str(content)

        is_url = False
        if isinstance(content, str):
            try:
                url = urlparse(content)
                is_url = (url.scheme and url.netloc) or url.scheme == "file"
            except Exception:
                pass

        def get_file_type(path: str) -> DataType | None:
            mapping = {
                ".pdf": DataType.PDF_FILE,
                ".csv": DataType.CSV,
                ".mdx": DataType.MDX,
                ".md": DataType.MDX,
                ".docx": DataType.DOCX,
                ".json": DataType.JSON,
                ".xml": DataType.XML,
                ".txt": DataType.TEXT_FILE,
            }
            for ext, dtype in mapping.items():
                if path.endswith(ext):
                    return dtype
            return None

        if is_url:
            dtype = get_file_type(url.path)
            if dtype:
                return dtype

            if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
                return DataType.DOCS_SITE
            if "github.com" in url.netloc:
                return DataType.GITHUB

            return DataType.WEBSITE

        if os.path.isfile(content):
            dtype = get_file_type(content)
            if dtype:
                return dtype

            if os.path.exists(content):
                return DataType.TEXT_FILE
        elif os.path.isdir(content):
            return DataType.DIRECTORY

        return DataType.TEXT
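For orientation, a few examples of how DataTypes.from_content resolves sources (illustrative URLs; behavior follows the code above, and local paths are only classified by extension if the file actually exists):

    from crewai_tools.rag.data_types import DataTypes, DataType

    assert DataTypes.from_content("https://example.com/data.csv") == DataType.CSV
    assert DataTypes.from_content("https://docs.example.com/guide") == DataType.DOCS_SITE
    assert DataTypes.from_content("https://github.com/crewAIInc/crewAI") == DataType.GITHUB
    assert DataTypes.from_content("just some plain text") == DataType.TEXT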
20 src/crewai_tools/rag/loaders/__init__.py Normal file
@@ -0,0 +1,20 @@
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader

__all__ = [
    "TextFileLoader",
    "TextLoader",
    "XMLLoader",
    "WebPageLoader",
    "MDXLoader",
    "JSONLoader",
    "DOCXLoader",
    "CSVLoader",
    "DirectoryLoader",
]
72 src/crewai_tools/rag/loaders/csv_loader.py Normal file
@@ -0,0 +1,72 @@
import csv
from io import StringIO

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class CSVLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref

        content_str = source_content.source
        if source_content.is_url():
            content_str = self._load_from_url(content_str, kwargs)
        elif source_content.path_exists():
            content_str = self._load_from_file(content_str)

        return self._parse_csv(content_str, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "text/csv, application/csv, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)"
        })

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_csv(self, content: str, source_ref: str) -> LoaderResult:
        try:
            csv_reader = csv.DictReader(StringIO(content))

            text_parts = []
            headers = csv_reader.fieldnames

            if headers:
                text_parts.append("Headers: " + " | ".join(headers))
                text_parts.append("-" * 50)

            for row_num, row in enumerate(csv_reader, 1):
                row_text = " | ".join([f"{k}: {v}" for k, v in row.items() if v])
                text_parts.append(f"Row {row_num}: {row_text}")

            text = "\n".join(text_parts)

            metadata = {
                "format": "csv",
                "columns": headers,
                "rows": len(text_parts) - 2 if headers else 0
            }

        except Exception as e:
            text = content
            metadata = {"format": "csv", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
        )
142 src/crewai_tools/rag/loaders/directory_loader.py Normal file
@@ -0,0 +1,142 @@
import os
from pathlib import Path
from typing import List

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DirectoryLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        """
        Load and process all files from a directory, recursively by default.

        Args:
            source_content: Local directory path to load from (URLs are not supported)
            **kwargs: Additional options:
                - recursive: bool (default True) - Whether to search recursively
                - include_extensions: list - Only include files with these extensions
                - exclude_extensions: list - Exclude files with these extensions
                - max_files: int - Maximum number of files to process
        """
        source_ref = source_content.source_ref

        if source_content.is_url():
            raise ValueError("URL directory loading is not supported. Please provide a local directory path.")

        if not os.path.exists(source_ref):
            raise FileNotFoundError(f"Directory does not exist: {source_ref}")

        if not os.path.isdir(source_ref):
            raise ValueError(f"Path is not a directory: {source_ref}")

        return self._process_directory(source_ref, kwargs)

    def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
        recursive = kwargs.get("recursive", True)
        include_extensions = kwargs.get("include_extensions", None)
        exclude_extensions = kwargs.get("exclude_extensions", None)
        max_files = kwargs.get("max_files", None)

        files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions)

        if max_files and len(files) > max_files:
            files = files[:max_files]

        all_contents = []
        processed_files = []
        errors = []

        for file_path in files:
            try:
                result = self._process_single_file(file_path)
                if result:
                    all_contents.append(f"=== File: {file_path} ===\n{result.content}")
                    processed_files.append({
                        "path": file_path,
                        "metadata": result.metadata,
                        "source": result.source
                    })
            except Exception as e:
                error_msg = f"Error processing {file_path}: {str(e)}"
                errors.append(error_msg)
                all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")

        combined_content = "\n\n".join(all_contents)

        metadata = {
            "format": "directory",
            "directory_path": dir_path,
            "total_files": len(files),
            "processed_files": len(processed_files),
            "errors": len(errors),
            "file_details": processed_files,
            "error_details": errors
        }

        return LoaderResult(
            content=combined_content,
            source=dir_path,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content)
        )

    def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]:
        """Find all files in the directory matching the criteria."""
        files = []

        if recursive:
            for root, dirs, filenames in os.walk(dir_path):
                dirs[:] = [d for d in dirs if not d.startswith('.')]

                for filename in filenames:
                    if self._should_include_file(filename, include_ext, exclude_ext):
                        files.append(os.path.join(root, filename))
        else:
            try:
                for item in os.listdir(dir_path):
                    item_path = os.path.join(dir_path, item)
                    if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext):
                        files.append(item_path)
            except PermissionError:
                pass

        return sorted(files)

    def _should_include_file(self, filename: str, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> bool:
        """Determine if a file should be included based on the criteria."""
        if filename.startswith('.'):
            return False

        _, ext = os.path.splitext(filename.lower())

        if include_ext:
            if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]:
                return False

        if exclude_ext:
            if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]:
                return False

        return True

    def _process_single_file(self, file_path: str) -> LoaderResult:
        from crewai_tools.rag.data_types import DataTypes

        data_type = DataTypes.from_content(Path(file_path))

        loader = data_type.get_loader()

        result = loader.load(SourceContent(file_path))

        if result.metadata is None:
            result.metadata = {}

        result.metadata.update({
            "file_path": file_path,
            "file_size": os.path.getsize(file_path),
            "data_type": str(data_type),
            "loader_type": loader.__class__.__name__
        })

        return result
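A short usage sketch of the directory loader (the "./docs" path is hypothetical; the options mirror the kwargs documented in load() above):

    from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
    from crewai_tools.rag.source_content import SourceContent

    loader = DirectoryLoader()
    result = loader.load(
        SourceContent("./docs"),
        recursive=True,
        include_extensions=[".md", ".txt"],
        max_files=100,
    )
    # Each file's text is concatenated under an "=== File: ... ===" banner.
    print(result.metadata["processed_files"], "files processed")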
72 src/crewai_tools/rag/loaders/docx_loader.py Normal file
@@ -0,0 +1,72 @@
import os
import tempfile

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DOCXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        try:
            from docx import Document as DocxDocument
        except ImportError:
            raise ImportError("python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]")

        source_ref = source_content.source_ref

        if source_content.is_url():
            temp_file = self._download_from_url(source_ref, kwargs)
            try:
                return self._load_from_file(temp_file, source_ref, DocxDocument)
            finally:
                os.unlink(temp_file)
        elif source_content.path_exists():
            return self._load_from_file(source_ref, source_ref, DocxDocument)
        else:
            raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}")

    def _download_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)"
        })

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            # Create a temporary file to save the DOCX content
            with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file:
                temp_file.write(response.content)
                return temp_file.name
        except Exception as e:
            raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}")

    def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult:
        try:
            doc = DocxDocument(file_path)

            text_parts = []
            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    text_parts.append(paragraph.text)

            content = "\n".join(text_parts)

            metadata = {
                "format": "docx",
                "paragraphs": len(doc.paragraphs),
                "tables": len(doc.tables)
            }

            return LoaderResult(
                content=content,
                source=source_ref,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
            )

        except Exception as e:
            raise ValueError(f"Error loading DOCX file: {str(e)}")
69 src/crewai_tools/rag/loaders/json_loader.py Normal file
@@ -0,0 +1,69 @@
import json

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult


class JSONLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_json(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/json",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)"
        })

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text if not self._is_json_response(response) else json.dumps(response.json(), indent=2)
        except Exception as e:
            raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}")

    def _is_json_response(self, response) -> bool:
        try:
            response.json()
            return True
        except ValueError:
            return False

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_json(self, content: str, source_ref: str) -> LoaderResult:
        try:
            data = json.loads(content)
            if isinstance(data, dict):
                text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
            elif isinstance(data, list):
                text = "\n".join(json.dumps(item, indent=0) for item in data)
            else:
                text = json.dumps(data, indent=0)

            metadata = {
                "format": "json",
                "type": type(data).__name__,
                "size": len(data) if isinstance(data, (list, dict)) else 1
            }
        except json.JSONDecodeError as e:
            text = content
            metadata = {"format": "json", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
        )
59 src/crewai_tools/rag/loaders/mdx_loader.py Normal file
@@ -0,0 +1,59 @@
import re

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MDXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_mdx(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "text/markdown, text/x-markdown, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)"
        })

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
        cleaned_content = content

        # Remove import statements
        cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE)

        # Remove export statements
        cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE)

        # Remove JSX tags (simple approach)
        cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content)

        # Clean up extra whitespace
        cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content)
        cleaned_content = cleaned_content.strip()

        metadata = {"format": "mdx"}
        return LoaderResult(
            content=cleaned_content,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content)
        )
28 src/crewai_tools/rag/loaders/text_loader.py Normal file
@@ -0,0 +1,28 @@
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TextFileLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        if not source_content.path_exists():
            raise FileNotFoundError(f"The following file does not exist: {source_content.source}")

        with open(source_content.source, "r", encoding="utf-8") as file:
            content = file.read()

        return LoaderResult(
            content=content,
            source=source_ref,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
        )


class TextLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        return LoaderResult(
            content=source_content.source,
            source=source_content.source_ref,
            doc_id=self.generate_doc_id(content=source_content.source)
        )
47 src/crewai_tools/rag/loaders/webpage_loader.py Normal file
@@ -0,0 +1,47 @@
import re
import requests
from bs4 import BeautifulSoup

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class WebPageLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        url = source_content.source
        headers = kwargs.get("headers", {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
            "Accept-Language": "en-US,en;q=0.9",
        })

        try:
            response = requests.get(url, timeout=15, headers=headers)
            response.encoding = response.apparent_encoding

            soup = BeautifulSoup(response.text, "html.parser")

            for script in soup(["script", "style"]):
                script.decompose()

            text = soup.get_text(" ")
            text = re.sub("[ \t]+", " ", text)
            text = re.sub("\\s+\n\\s+", "\n", text)
            text = text.strip()

            title = soup.title.string.strip() if soup.title and soup.title.string else ""
            metadata = {
                "url": url,
                "title": title,
                "status_code": response.status_code,
                "content_type": response.headers.get("content-type", "")
            }

            return LoaderResult(
                content=text,
                source=url,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=url, content=text)
            )

        except Exception as e:
            raise ValueError(f"Error loading webpage {url}: {str(e)}")
61 src/crewai_tools/rag/loaders/xml_loader.py Normal file
@@ -0,0 +1,61 @@
import os
import xml.etree.ElementTree as ET

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class XMLLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        content = source_content.source

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif os.path.exists(source_ref):
            content = self._load_from_file(source_ref)

        return self._parse_xml(content, source_ref)

    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/xml, text/xml, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)"
        })

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching XML from URL {url}: {str(e)}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
        try:
            if content.strip().startswith('<'):
                root = ET.fromstring(content)
            else:
                root = ET.parse(source_ref).getroot()

            text_parts = []
            for text_content in root.itertext():
                if text_content and text_content.strip():
                    text_parts.append(text_content.strip())

            text = "\n".join(text_parts)
            metadata = {"format": "xml", "root_tag": root.tag}
        except ET.ParseError as e:
            text = content
            metadata = {"format": "xml", "parse_error": str(e)}

        return LoaderResult(
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
        )
4 src/crewai_tools/rag/misc.py Normal file
@@ -0,0 +1,4 @@
import hashlib

def compute_sha256(content: str) -> str:
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
46 src/crewai_tools/rag/source_content.py Normal file
@@ -0,0 +1,46 @@
import os
from urllib.parse import urlparse
from typing import TYPE_CHECKING
from pathlib import Path
from functools import cached_property

from crewai_tools.rag.misc import compute_sha256

if TYPE_CHECKING:
    from crewai_tools.rag.data_types import DataType


class SourceContent:
    def __init__(self, source: str | Path):
        self.source = str(source)

    def is_url(self) -> bool:
        if not isinstance(self.source, str):
            return False
        try:
            parsed_url = urlparse(self.source)
            return bool(parsed_url.scheme and parsed_url.netloc)
        except Exception:
            return False

    def path_exists(self) -> bool:
        return os.path.exists(self.source)

    @cached_property
    def data_type(self) -> "DataType":
        from crewai_tools.rag.data_types import DataTypes

        return DataTypes.from_content(self.source)

    @cached_property
    def source_ref(self) -> str:
        """
        Returns the source reference for the content.
        If the content is a URL or a local file, returns the source.
        Otherwise, returns the hash of the content.
        """
        if self.is_url() or self.path_exists():
            return self.source

        return compute_sha256(self.source)
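A small sketch of how SourceContent normalizes different inputs (illustrative values; behavior follows the class above):

    from crewai_tools.rag.source_content import SourceContent

    web = SourceContent("https://example.com/page")
    raw = SourceContent("inline text passed straight to the RAG pipeline")

    assert web.is_url() and web.source_ref == "https://example.com/page"
    # Raw text has no stable path or URL, so its SHA-256 hex digest becomes the reference.
    assert len(raw.source_ref) == 64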
0 tests/rag/__init__.py Normal file (empty)
130
tests/rag/test_csv_loader.py
Normal file
130
tests/rag/test_csv_loader.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from crewai_tools.rag.loaders.csv_loader import CSVLoader
|
||||
from crewai_tools.rag.base_loader import LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
@pytest.fixture
def temp_csv_file():
    created_files = []

    def _create(content: str):
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False)
        f.write(content)
        f.close()
        created_files.append(f.name)
        return f.name

    yield _create

    for path in created_files:
        os.unlink(path)


class TestCSVLoader:
    def test_load_csv_from_file(self, temp_csv_file):
        path = temp_csv_file("name,age,city\nJohn,25,New York\nJane,30,Chicago")
        loader = CSVLoader()
        result = loader.load(SourceContent(path))

        assert isinstance(result, LoaderResult)
        assert "Headers: name | age | city" in result.content
        assert "Row 1: name: John | age: 25 | city: New York" in result.content
        assert "Row 2: name: Jane | age: 30 | city: Chicago" in result.content
        assert result.metadata == {
            "format": "csv",
            "columns": ["name", "age", "city"],
            "rows": 2,
        }
        assert result.source == path
        assert result.doc_id

    def test_load_csv_with_empty_values(self, temp_csv_file):
        path = temp_csv_file("name,age,city\nJohn,,New York\n,30,")
        result = CSVLoader().load(SourceContent(path))

        assert "Row 1: name: John | city: New York" in result.content
        assert "Row 2: age: 30" in result.content
        assert result.metadata["rows"] == 2

    def test_load_csv_malformed(self, temp_csv_file):
        path = temp_csv_file("invalid,csv\nunclosed quote \"missing")
        result = CSVLoader().load(SourceContent(path))

        assert "Headers: invalid | csv" in result.content
        assert 'Row 1: invalid: unclosed quote "missing' in result.content
        assert result.metadata["columns"] == ["invalid", "csv"]

    def test_load_csv_empty_file(self, temp_csv_file):
        path = temp_csv_file("")
        result = CSVLoader().load(SourceContent(path))

        assert result.content == ""
        assert result.metadata["rows"] == 0

    def test_load_csv_text_input(self):
        raw_csv = "col1,col2\nvalue1,value2\nvalue3,value4"
        result = CSVLoader().load(SourceContent(raw_csv))

        assert "Headers: col1 | col2" in result.content
        assert "Row 1: col1: value1 | col2: value2" in result.content
        assert "Row 2: col1: value3 | col2: value4" in result.content
        assert result.metadata["columns"] == ["col1", "col2"]
        assert result.metadata["rows"] == 2

    def test_doc_id_is_deterministic(self, temp_csv_file):
        path = temp_csv_file("name,value\ntest,123")
        loader = CSVLoader()

        result1 = loader.load(SourceContent(path))
        result2 = loader.load(SourceContent(path))

        assert result1.doc_id == result2.doc_id

    @patch("requests.get")
    def test_load_csv_from_url(self, mock_get):
        mock_get.return_value = Mock(
            text="name,value\ntest,123",
            raise_for_status=Mock(return_value=None)
        )

        result = CSVLoader().load(SourceContent("https://example.com/data.csv"))

        assert "Headers: name | value" in result.content
        assert "Row 1: name: test | value: 123" in result.content
        headers = mock_get.call_args[1]["headers"]
        assert "text/csv" in headers["Accept"]
        assert "crewai-tools CSVLoader" in headers["User-Agent"]

    @patch("requests.get")
    def test_load_csv_with_custom_headers(self, mock_get):
        mock_get.return_value = Mock(
            text="data,value\ntest,456",
            raise_for_status=Mock(return_value=None)
        )
        headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
        result = CSVLoader().load(SourceContent("https://example.com/data.csv"), headers=headers)

        assert "Headers: data | value" in result.content
        assert mock_get.call_args[1]["headers"] == headers

    @patch("requests.get")
    def test_csv_loader_handles_network_errors(self, mock_get):
        mock_get.side_effect = Exception("Network error")
        loader = CSVLoader()

        with pytest.raises(ValueError, match="Error fetching CSV from URL"):
            loader.load(SourceContent("https://example.com/data.csv"))

    @patch("requests.get")
    def test_csv_loader_handles_http_error(self, mock_get):
        mock_get.return_value = Mock()
        mock_get.return_value.raise_for_status.side_effect = Exception("404 Not Found")
        loader = CSVLoader()

        with pytest.raises(ValueError, match="Error fetching CSV from URL"):
            loader.load(SourceContent("https://example.com/notfound.csv"))
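The assertions above pin down the CSVLoader output format: a "Headers: ..." line joined with " | ", then one "Row N: col: value | ..." line per record, with empty cells skipped. The following is a minimal sketch of that serialization derived only from these assertions, not the loader's actual internals; format_csv is an illustrative helper name.

import csv
import io

def format_csv(raw_csv: str) -> tuple[str, dict]:
    # Parse raw CSV text and serialize it the way the tests above expect.
    reader = csv.DictReader(io.StringIO(raw_csv))
    columns = reader.fieldnames or []
    lines = []
    if columns:
        lines.append("Headers: " + " | ".join(columns))
    rows = 0
    for row in reader:
        rows += 1
        # Empty cells are dropped, matching test_load_csv_with_empty_values.
        cells = [f"{col}: {val}" for col, val in row.items() if val]
        lines.append(f"Row {rows}: " + " | ".join(cells))
    metadata = {"format": "csv", "columns": columns, "rows": rows}
    return "\n".join(lines), metadata

content, metadata = format_csv("name,age,city\nJohn,25,New York\nJane,30,Chicago")
assert "Row 1: name: John | age: 25 | city: New York" in content
assert metadata["rows"] == 2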
149
tests/rag/test_directory_loader.py
Normal file
@@ -0,0 +1,149 @@
import os
import tempfile
import pytest

from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


@pytest.fixture
def temp_directory():
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir


class TestDirectoryLoader:
    def _create_file(self, directory, filename, content="test content"):
        path = os.path.join(directory, filename)
        with open(path, "w") as f:
            f.write(content)
        return path

    def test_load_non_recursive(self, temp_directory):
        self._create_file(temp_directory, "file1.txt")
        self._create_file(temp_directory, "file2.txt")
        subdir = os.path.join(temp_directory, "subdir")
        os.makedirs(subdir)
        self._create_file(subdir, "file3.txt")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory), recursive=False)

        assert isinstance(result, LoaderResult)
        assert "file1.txt" in result.content
        assert "file2.txt" in result.content
        assert "file3.txt" not in result.content
        assert result.metadata["total_files"] == 2

    def test_load_recursive(self, temp_directory):
        self._create_file(temp_directory, "file1.txt")
        nested = os.path.join(temp_directory, "subdir", "nested")
        os.makedirs(nested)
        self._create_file(os.path.join(temp_directory, "subdir"), "file2.txt")
        self._create_file(nested, "file3.txt")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory), recursive=True)

        assert all(f"file{i}.txt" in result.content for i in range(1, 4))

    def test_include_and_exclude_extensions(self, temp_directory):
        self._create_file(temp_directory, "a.txt")
        self._create_file(temp_directory, "b.py")
        self._create_file(temp_directory, "c.md")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory), include_extensions=[".txt", ".py"])
        assert "a.txt" in result.content
        assert "b.py" in result.content
        assert "c.md" not in result.content

        result2 = loader.load(SourceContent(temp_directory), exclude_extensions=[".py", ".md"])
        assert "a.txt" in result2.content
        assert "b.py" not in result2.content
        assert "c.md" not in result2.content

    def test_max_files_limit(self, temp_directory):
        for i in range(5):
            self._create_file(temp_directory, f"file{i}.txt")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory), max_files=3)

        assert result.metadata["total_files"] == 3
        assert all(f"file{i}.txt" in result.content for i in range(3))

    def test_hidden_files_and_dirs_excluded(self, temp_directory):
        self._create_file(temp_directory, "visible.txt", "visible")
        self._create_file(temp_directory, ".hidden.txt", "hidden")

        hidden_dir = os.path.join(temp_directory, ".hidden")
        os.makedirs(hidden_dir)
        self._create_file(hidden_dir, "inside_hidden.txt")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory), recursive=True)

        assert "visible.txt" in result.content
        assert ".hidden.txt" not in result.content
        assert "inside_hidden.txt" not in result.content

    def test_directory_does_not_exist(self):
        loader = DirectoryLoader()
        with pytest.raises(FileNotFoundError, match="Directory does not exist"):
            loader.load(SourceContent("/path/does/not/exist"))

    def test_path_is_not_a_directory(self):
        with tempfile.NamedTemporaryFile() as f:
            loader = DirectoryLoader()
            with pytest.raises(ValueError, match="Path is not a directory"):
                loader.load(SourceContent(f.name))

    def test_url_not_supported(self):
        loader = DirectoryLoader()
        with pytest.raises(ValueError, match="URL directory loading is not supported"):
            loader.load(SourceContent("https://example.com"))

    def test_processing_error_handling(self, temp_directory):
        self._create_file(temp_directory, "valid.txt")
        error_file = self._create_file(temp_directory, "error.txt")

        loader = DirectoryLoader()
        original_method = loader._process_single_file

        def mock(file_path):
            if "error" in file_path:
                raise ValueError("Mock error")
            return original_method(file_path)

        loader._process_single_file = mock
        result = loader.load(SourceContent(temp_directory))

        assert "valid.txt" in result.content
        assert "error.txt (ERROR)" in result.content
        assert result.metadata["errors"] == 1
        assert len(result.metadata["error_details"]) == 1

    def test_metadata_structure(self, temp_directory):
        self._create_file(temp_directory, "test.txt", "Sample")

        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory))
        metadata = result.metadata

        expected_keys = {
            "format", "directory_path", "total_files", "processed_files",
            "errors", "file_details", "error_details"
        }

        assert expected_keys.issubset(metadata)
        assert all(k in metadata["file_details"][0] for k in ("path", "metadata", "source"))

    def test_empty_directory(self, temp_directory):
        loader = DirectoryLoader()
        result = loader.load(SourceContent(temp_directory))

        assert result.content == ""
        assert result.metadata["total_files"] == 0
        assert result.metadata["processed_files"] == 0
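These tests constrain how DirectoryLoader selects files: hidden files and dot-directories are skipped, include/exclude extension filters apply, max_files caps the result, recursive=False stops at the top level, and bad paths raise FileNotFoundError or ValueError. A rough sketch of such a selection pass under those constraints follows; collect_files is an illustrative name, not the loader's real internals.

import os

def collect_files(directory, recursive=True, include_extensions=None,
                  exclude_extensions=None, max_files=None):
    # Walk the directory, applying the filters the DirectoryLoader tests assert on.
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory does not exist: {directory}")
    if not os.path.isdir(directory):
        raise ValueError(f"Path is not a directory: {directory}")
    selected = []
    for root, dirs, files in os.walk(directory):
        # Prune hidden directories in place so os.walk never descends into them.
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for name in sorted(files):
            if name.startswith("."):
                continue
            ext = os.path.splitext(name)[1]
            if include_extensions and ext not in include_extensions:
                continue
            if exclude_extensions and ext in exclude_extensions:
                continue
            selected.append(os.path.join(root, name))
            if max_files and len(selected) >= max_files:
                return selected
        if not recursive:
            break
    return selected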
135
tests/rag/test_docx_loader.py
Normal file
@@ -0,0 +1,135 @@
import tempfile
import pytest
from unittest.mock import patch, Mock

from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TestDOCXLoader:
    @patch('docx.Document')
    def test_load_docx_from_file(self, mock_docx_class):
        mock_doc = Mock()
        mock_doc.paragraphs = [
            Mock(text="First paragraph"),
            Mock(text="Second paragraph"),
            Mock(text=" ")  # Blank paragraph
        ]
        mock_doc.tables = []
        mock_docx_class.return_value = mock_doc

        with tempfile.NamedTemporaryFile(suffix='.docx') as f:
            loader = DOCXLoader()
            result = loader.load(SourceContent(f.name))

            assert isinstance(result, LoaderResult)
            assert result.content == "First paragraph\nSecond paragraph"
            assert result.metadata == {"format": "docx", "paragraphs": 3, "tables": 0}
            assert result.source == f.name

    @patch('docx.Document')
    def test_load_docx_with_tables(self, mock_docx_class):
        mock_doc = Mock()
        mock_doc.paragraphs = [Mock(text="Document with table")]
        mock_doc.tables = [Mock(), Mock()]
        mock_docx_class.return_value = mock_doc

        with tempfile.NamedTemporaryFile(suffix='.docx') as f:
            loader = DOCXLoader()
            result = loader.load(SourceContent(f.name))

            assert result.metadata["tables"] == 2

    @patch('requests.get')
    @patch('docx.Document')
    @patch('tempfile.NamedTemporaryFile')
    @patch('os.unlink')
    def test_load_docx_from_url(self, mock_unlink, mock_tempfile, mock_docx_class, mock_get):
        mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())

        mock_temp = Mock(name="/tmp/temp_docx_file.docx")
        mock_temp.__enter__ = Mock(return_value=mock_temp)
        mock_temp.__exit__ = Mock(return_value=None)
        mock_tempfile.return_value = mock_temp

        mock_doc = Mock()
        mock_doc.paragraphs = [Mock(text="Content from URL")]
        mock_doc.tables = []
        mock_docx_class.return_value = mock_doc

        loader = DOCXLoader()
        result = loader.load(SourceContent("https://example.com/test.docx"))

        assert "Content from URL" in result.content
        assert result.source == "https://example.com/test.docx"

        headers = mock_get.call_args[1]['headers']
        assert "application/vnd.openxmlformats-officedocument.wordprocessingml.document" in headers['Accept']
        assert "crewai-tools DOCXLoader" in headers['User-Agent']

        mock_temp.write.assert_called_once_with(b"fake docx content")

    @patch('requests.get')
    @patch('docx.Document')
    def test_load_docx_from_url_with_custom_headers(self, mock_docx_class, mock_get):
        mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
        mock_docx_class.return_value = Mock(paragraphs=[], tables=[])

        loader = DOCXLoader()
        custom_headers = {"Authorization": "Bearer token"}

        with patch('tempfile.NamedTemporaryFile'), patch('os.unlink'):
            loader.load(SourceContent("https://example.com/test.docx"), headers=custom_headers)

        assert mock_get.call_args[1]['headers'] == custom_headers

    @patch('requests.get')
    def test_load_docx_url_download_error(self, mock_get):
        mock_get.side_effect = Exception("Network error")

        loader = DOCXLoader()
        with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
            loader.load(SourceContent("https://example.com/test.docx"))

    @patch('requests.get')
    def test_load_docx_url_http_error(self, mock_get):
        mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404 Not Found")))

        loader = DOCXLoader()
        with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
            loader.load(SourceContent("https://example.com/notfound.docx"))

    def test_load_docx_invalid_source(self):
        loader = DOCXLoader()
        with pytest.raises(ValueError, match="Source must be a valid file path or URL"):
            loader.load(SourceContent("not_a_file_or_url"))

    @patch('docx.Document')
    def test_load_docx_parsing_error(self, mock_docx_class):
        mock_docx_class.side_effect = Exception("Invalid DOCX file")

        with tempfile.NamedTemporaryFile(suffix='.docx') as f:
            loader = DOCXLoader()
            with pytest.raises(ValueError, match="Error loading DOCX file"):
                loader.load(SourceContent(f.name))

    @patch('docx.Document')
    def test_load_docx_empty_document(self, mock_docx_class):
        mock_docx_class.return_value = Mock(paragraphs=[], tables=[])

        with tempfile.NamedTemporaryFile(suffix='.docx') as f:
            loader = DOCXLoader()
            result = loader.load(SourceContent(f.name))

            assert result.content == ""
            assert result.metadata == {"paragraphs": 0, "tables": 0, "format": "docx"}

    @patch('docx.Document')
    def test_docx_doc_id_generation(self, mock_docx_class):
        mock_docx_class.return_value = Mock(paragraphs=[Mock(text="Consistent content")], tables=[])

        with tempfile.NamedTemporaryFile(suffix='.docx') as f:
            loader = DOCXLoader()
            source = SourceContent(f.name)
            assert loader.load(source).doc_id == loader.load(source).doc_id
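The DOCX assertions describe a simple extraction: paragraph texts joined with newlines, blank paragraphs dropped from the content but still counted in metadata, and tables only counted. A sketch of that extraction step follows, assuming the python-docx package as the tests' patch target implies; extract_docx is an illustrative name.

import docx  # python-docx, assumed available (the tests patch docx.Document)

def extract_docx(path: str) -> tuple[str, dict]:
    doc = docx.Document(path)
    # Keep only non-blank paragraphs in the content, but count them all.
    texts = [p.text for p in doc.paragraphs if p.text.strip()]
    metadata = {
        "format": "docx",
        "paragraphs": len(doc.paragraphs),
        "tables": len(doc.tables),
    }
    return "\n".join(texts), metadata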
180
tests/rag/test_json_loader.py
Normal file
@@ -0,0 +1,180 @@
import json
import os
import tempfile
import pytest
from unittest.mock import patch, Mock

from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TestJSONLoader:
    def _create_temp_json_file(self, data) -> str:
        """Helper to write JSON data to a temporary file and return its path."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(data, f)
            return f.name

    def _create_temp_raw_file(self, content: str) -> str:
        """Helper to write raw content to a temporary file and return its path."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            f.write(content)
            return f.name

    def _load_from_path(self, path) -> LoaderResult:
        loader = JSONLoader()
        return loader.load(SourceContent(path))

    def test_load_json_dict(self):
        path = self._create_temp_json_file({"name": "John", "age": 30, "items": ["a", "b", "c"]})
        try:
            result = self._load_from_path(path)
            assert isinstance(result, LoaderResult)
            assert all(k in result.content for k in ["name", "John", "age", "30"])
            assert result.metadata == {
                "format": "json", "type": "dict", "size": 3
            }
            assert result.source == path
        finally:
            os.unlink(path)

    def test_load_json_list(self):
        path = self._create_temp_json_file([
            {"id": 1, "name": "Item 1"},
            {"id": 2, "name": "Item 2"},
        ])
        try:
            result = self._load_from_path(path)
            assert result.metadata["type"] == "list"
            assert result.metadata["size"] == 2
            assert all(item in result.content for item in ["Item 1", "Item 2"])
        finally:
            os.unlink(path)

    @pytest.mark.parametrize("value, expected_type", [
        ("simple string value", "str"),
        (42, "int"),
    ])
    def test_load_json_primitives(self, value, expected_type):
        path = self._create_temp_json_file(value)
        try:
            result = self._load_from_path(path)
            assert result.metadata["type"] == expected_type
            assert result.metadata["size"] == 1
            assert str(value) in result.content
        finally:
            os.unlink(path)

    def test_load_malformed_json(self):
        path = self._create_temp_raw_file('{"invalid": json,}')
        try:
            result = self._load_from_path(path)
            assert result.metadata["format"] == "json"
            assert "parse_error" in result.metadata
            assert result.content == '{"invalid": json,}'
        finally:
            os.unlink(path)

    def test_load_empty_file(self):
        path = self._create_temp_raw_file('')
        try:
            result = self._load_from_path(path)
            assert "parse_error" in result.metadata
            assert result.content == ''
        finally:
            os.unlink(path)

    def test_load_text_input(self):
        json_text = '{"message": "hello", "count": 5}'
        loader = JSONLoader()
        result = loader.load(SourceContent(json_text))
        assert all(part in result.content for part in ["message", "hello", "count", "5"])
        assert result.metadata["type"] == "dict"
        assert result.metadata["size"] == 2

    def test_load_complex_nested_json(self):
        data = {
            "users": [
                {"id": 1, "profile": {"name": "Alice", "settings": {"theme": "dark"}}},
                {"id": 2, "profile": {"name": "Bob", "settings": {"theme": "light"}}}
            ],
            "meta": {"total": 2, "version": "1.0"}
        }
        path = self._create_temp_json_file(data)
        try:
            result = self._load_from_path(path)
            for value in ["Alice", "Bob", "dark", "light"]:
                assert value in result.content
            assert result.metadata["size"] == 2  # top-level keys
        finally:
            os.unlink(path)

    def test_consistent_doc_id(self):
        path = self._create_temp_json_file({"test": "data"})
        try:
            result1 = self._load_from_path(path)
            result2 = self._load_from_path(path)
            assert result1.doc_id == result2.doc_id
        finally:
            os.unlink(path)

    # ------------------------------
    # URL-based tests
    # ------------------------------

    @patch('requests.get')
    def test_url_response_valid_json(self, mock_get):
        mock_get.return_value = Mock(
            text='{"key": "value", "number": 123}',
            json=Mock(return_value={"key": "value", "number": 123}),
            raise_for_status=Mock()
        )

        loader = JSONLoader()
        result = loader.load(SourceContent("https://api.example.com/data.json"))

        assert all(val in result.content for val in ["key", "value", "number", "123"])
        headers = mock_get.call_args[1]['headers']
        assert "application/json" in headers['Accept']
        assert "crewai-tools JSONLoader" in headers['User-Agent']

    @patch('requests.get')
    def test_url_response_not_json(self, mock_get):
        mock_get.return_value = Mock(
            text='{"key": "value"}',
            json=Mock(side_effect=ValueError("Not JSON")),
            raise_for_status=Mock()
        )

        loader = JSONLoader()
        result = loader.load(SourceContent("https://example.com/data.json"))
        assert all(part in result.content for part in ["key", "value"])

    @patch('requests.get')
    def test_url_with_custom_headers(self, mock_get):
        mock_get.return_value = Mock(
            text='{"data": "test"}',
            json=Mock(return_value={"data": "test"}),
            raise_for_status=Mock()
        )
        headers = {"Authorization": "Bearer token", "Custom-Header": "value"}

        loader = JSONLoader()
        loader.load(SourceContent("https://api.example.com/data.json"), headers=headers)

        assert mock_get.call_args[1]['headers'] == headers

    @patch('requests.get')
    def test_url_network_failure(self, mock_get):
        mock_get.side_effect = Exception("Network error")
        loader = JSONLoader()
        with pytest.raises(ValueError, match="Error fetching JSON from URL"):
            loader.load(SourceContent("https://api.example.com/data.json"))

    @patch('requests.get')
    def test_url_http_error(self, mock_get):
        mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404")))
        loader = JSONLoader()
        with pytest.raises(ValueError, match="Error fetching JSON from URL"):
            loader.load(SourceContent("https://api.example.com/404.json"))
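The JSON tests fix the metadata shape: type is the parsed value's class name, size is len() for dicts and lists and 1 for primitives, and unparseable input keeps the raw text and records a parse_error entry. A small sketch of that classification under those assumptions; describe_json is an illustrative name.

import json

def describe_json(raw: str) -> tuple[str, dict]:
    metadata = {"format": "json"}
    try:
        parsed = json.loads(raw)
    except ValueError as exc:
        # Malformed or empty input: keep the raw text and record the error.
        metadata["parse_error"] = str(exc)
        return raw, metadata
    metadata["type"] = type(parsed).__name__
    metadata["size"] = len(parsed) if isinstance(parsed, (dict, list)) else 1
    content = json.dumps(parsed, indent=2, ensure_ascii=False)
    return content, metadata

content, meta = describe_json('{"name": "John", "age": 30, "items": ["a", "b", "c"]}')
assert meta == {"format": "json", "type": "dict", "size": 3}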
176
tests/rag/test_mdx_loader.py
Normal file
@@ -0,0 +1,176 @@
import os
import tempfile
import pytest
from unittest.mock import patch, Mock

from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TestMDXLoader:

    def _write_temp_mdx(self, content):
        f = tempfile.NamedTemporaryFile(mode='w', suffix='.mdx', delete=False)
        f.write(content)
        f.close()
        return f.name

    def _load_from_file(self, content):
        path = self._write_temp_mdx(content)
        try:
            loader = MDXLoader()
            return loader.load(SourceContent(path)), path
        finally:
            os.unlink(path)

    def test_load_basic_mdx_file(self):
        content = """
import Component from './Component'
export const meta = { title: 'Test' }

# Test MDX File

This is a **markdown** file with JSX.

<Component prop="value" />

Some more content.

<div className="container">
<p>Nested content</p>
</div>
"""
        result, path = self._load_from_file(content)

        assert isinstance(result, LoaderResult)
        assert all(tag not in result.content for tag in ["import", "export", "<Component", "<div", "</div>"])
        assert all(text in result.content for text in ["# Test MDX File", "markdown", "Some more content", "Nested content"])
        assert result.metadata["format"] == "mdx"
        assert result.source == path

    def test_mdx_multiple_imports_exports(self):
        content = """
import React from 'react'
import { useState } from 'react'
import CustomComponent from './custom'

export default function Layout() { return null }
export const config = { test: true }

# Content

Regular markdown content here.
"""
        result, _ = self._load_from_file(content)
        assert "# Content" in result.content
        assert "Regular markdown content here." in result.content
        assert "import" not in result.content and "export" not in result.content

    def test_complex_jsx_cleanup(self):
        content = """
# MDX with Complex JSX

<div className="alert alert-info">
<strong>Info:</strong> This is important information.
<ul><li>Item 1</li><li>Item 2</li></ul>
</div>

Regular paragraph text.

<MyComponent prop1="value1">Nested content inside component</MyComponent>
"""
        result, _ = self._load_from_file(content)
        assert all(tag not in result.content for tag in ["<div", "<strong>", "<ul>", "<MyComponent"])
        assert all(text in result.content for text in ["Info:", "Item 1", "Regular paragraph text.", "Nested content inside component"])

    def test_whitespace_cleanup(self):
        content = """


# Title


Some content.


More content after multiple newlines.



Final content.
"""
        result, _ = self._load_from_file(content)
        assert result.content.count('\n\n\n') == 0
        assert result.content.startswith('# Title')
        assert result.content.endswith('Final content.')

    def test_only_jsx_content(self):
        content = """
<div>
<h1>Only JSX content</h1>
<p>No markdown here</p>
</div>
"""
        result, _ = self._load_from_file(content)
        assert all(tag not in result.content for tag in ["<div>", "<h1>", "<p>"])
        assert "Only JSX content" in result.content
        assert "No markdown here" in result.content

    @patch('requests.get')
    def test_load_mdx_from_url(self, mock_get):
        mock_get.return_value = Mock(text="# MDX from URL\n\nContent here.\n\n<Component />", raise_for_status=lambda: None)
        loader = MDXLoader()
        result = loader.load(SourceContent("https://example.com/content.mdx"))
        assert "# MDX from URL" in result.content
        assert "<Component />" not in result.content

    @patch('requests.get')
    def test_load_mdx_with_custom_headers(self, mock_get):
        mock_get.return_value = Mock(text="# Custom headers test", raise_for_status=lambda: None)
        loader = MDXLoader()
        loader.load(SourceContent("https://example.com"), headers={"Authorization": "Bearer token"})
        assert mock_get.call_args[1]['headers'] == {"Authorization": "Bearer token"}

    @patch('requests.get')
    def test_mdx_url_fetch_error(self, mock_get):
        mock_get.side_effect = Exception("Network error")
        with pytest.raises(ValueError, match="Error fetching MDX from URL"):
            MDXLoader().load(SourceContent("https://example.com"))

    def test_load_inline_mdx_text(self):
        content = """# Inline MDX\n\nimport Something from 'somewhere'\n\nContent with <Component prop=\"value\" />.\n\nexport const meta = { title: 'Test' }"""
        loader = MDXLoader()
        result = loader.load(SourceContent(content))
        assert "# Inline MDX" in result.content
        assert "Content with ." in result.content

    def test_empty_result_after_cleaning(self):
        content = """
import Something from 'somewhere'
export const config = {}
<div></div>
"""
        result, _ = self._load_from_file(content)
        assert result.content.strip() == ""

    def test_edge_case_parsing(self):
        content = """
# Title

<Component>
Multi-line
JSX content
</Component>

import { a, b } from 'module'

export { x, y }

Final text.
"""
        result, _ = self._load_from_file(content)
        assert "# Title" in result.content
        assert "JSX content" in result.content
        assert "Final text." in result.content
        assert all(phrase not in result.content for phrase in ["import {", "export {", "<Component>"])
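Taken together, the MDX tests describe the cleanup as: drop import/export statements, strip JSX/HTML tags while keeping their inner text, and collapse runs of three or more newlines. A regex-based sketch of that pipeline, derived only from these assertions; clean_mdx is an illustrative name, not the loader's actual code.

import re

def clean_mdx(text: str) -> str:
    # Drop import/export statements (including one-line export default bodies).
    text = re.sub(r"^\s*(import|export)\b.*$", "", text, flags=re.MULTILINE)
    # Strip JSX/HTML tags but keep the text between them.
    text = re.sub(r"</?[A-Za-z][^>]*>", "", text)
    # Collapse three or more newlines into a blank line, then trim.
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()

cleaned = clean_mdx("# Title\n\n<Component>\nJSX content\n</Component>\n\nimport { a } from 'module'\n")
assert "<Component>" not in cleaned and "import" not in cleaned
assert "JSX content" in cleaned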
160
tests/rag/test_text_loaders.py
Normal file
@@ -0,0 +1,160 @@
import hashlib
import os
import tempfile
import pytest

from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


def write_temp_file(content, suffix=".txt", encoding="utf-8"):
    with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding=encoding) as f:
        f.write(content)
        return f.name


def cleanup_temp_file(path):
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


class TestTextFileLoader:
    def test_basic_text_file(self):
        content = "This is test content\nWith multiple lines\nAnd more text"
        path = write_temp_file(content)
        try:
            result = TextFileLoader().load(SourceContent(path))
            assert isinstance(result, LoaderResult)
            assert result.content == content
            assert result.source == path
            assert result.doc_id
            assert result.metadata in (None, {})
        finally:
            cleanup_temp_file(path)

    def test_empty_file(self):
        path = write_temp_file("")
        try:
            result = TextFileLoader().load(SourceContent(path))
            assert result.content == ""
        finally:
            cleanup_temp_file(path)

    def test_unicode_content(self):
        content = "Hello 世界 🌍 émojis 🎉 åäö"
        path = write_temp_file(content)
        try:
            result = TextFileLoader().load(SourceContent(path))
            assert content in result.content
        finally:
            cleanup_temp_file(path)

    def test_large_file(self):
        content = "\n".join(f"Line {i}" for i in range(100))
        path = write_temp_file(content)
        try:
            result = TextFileLoader().load(SourceContent(path))
            assert "Line 0" in result.content
            assert "Line 99" in result.content
            assert result.content.count("\n") == 99
        finally:
            cleanup_temp_file(path)

    def test_missing_file(self):
        with pytest.raises(FileNotFoundError):
            TextFileLoader().load(SourceContent("/nonexistent/path.txt"))

    def test_permission_denied(self):
        path = write_temp_file("Some content")
        os.chmod(path, 0o000)
        try:
            with pytest.raises(PermissionError):
                TextFileLoader().load(SourceContent(path))
        finally:
            os.chmod(path, 0o644)
            cleanup_temp_file(path)

    def test_doc_id_consistency(self):
        content = "Consistent content"
        path = write_temp_file(content)
        try:
            loader = TextFileLoader()
            result1 = loader.load(SourceContent(path))
            result2 = loader.load(SourceContent(path))
            expected_id = hashlib.sha256((path + content).encode("utf-8")).hexdigest()
            assert result1.doc_id == result2.doc_id == expected_id
        finally:
            cleanup_temp_file(path)

    def test_various_extensions(self):
        content = "Same content"
        for ext in [".txt", ".md", ".log", ".json"]:
            path = write_temp_file(content, suffix=ext)
            try:
                result = TextFileLoader().load(SourceContent(path))
                assert result.content == content
            finally:
                cleanup_temp_file(path)


class TestTextLoader:
    def test_basic_text(self):
        content = "Raw text"
        result = TextLoader().load(SourceContent(content))
        expected_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
        assert result.content == content
        assert result.source == expected_hash
        assert result.doc_id == expected_hash

    def test_multiline_text(self):
        content = "Line 1\nLine 2\nLine 3"
        result = TextLoader().load(SourceContent(content))
        assert "Line 2" in result.content

    def test_empty_text(self):
        result = TextLoader().load(SourceContent(""))
        assert result.content == ""
        assert result.source == hashlib.sha256("".encode("utf-8")).hexdigest()

    def test_unicode_text(self):
        content = "世界 🌍 émojis 🎉 åäö"
        result = TextLoader().load(SourceContent(content))
        assert content in result.content

    def test_special_characters(self):
        content = "!@#$$%^&*()_+-=~`{}[]\\|;:'\",.<>/?"
        result = TextLoader().load(SourceContent(content))
        assert result.content == content

    def test_doc_id_uniqueness(self):
        result1 = TextLoader().load(SourceContent("A"))
        result2 = TextLoader().load(SourceContent("B"))
        assert result1.doc_id != result2.doc_id

    def test_whitespace_text(self):
        content = " \n\t "
        result = TextLoader().load(SourceContent(content))
        assert result.content == content

    def test_long_text(self):
        content = "A" * 10000
        result = TextLoader().load(SourceContent(content))
        assert len(result.content) == 10000


class TestTextLoadersIntegration:
    def test_consistency_between_loaders(self):
        content = "Consistent content"
        text_result = TextLoader().load(SourceContent(content))
        file_path = write_temp_file(content)
        try:
            file_result = TextFileLoader().load(SourceContent(file_path))

            assert text_result.content == file_result.content
            assert text_result.source != file_result.source
            assert text_result.doc_id != file_result.doc_id
        finally:
            cleanup_temp_file(file_path)
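test_doc_id_consistency and TestTextLoader.test_basic_text spell out the identifier scheme directly: a file-based load hashes path + content, while a raw-text load hashes the content alone and reuses that hash as the source. A sketch of both computations with hashlib; the helper names here are illustrative.

import hashlib

def file_doc_id(path: str, content: str) -> str:
    # Matches the expected_id computed in test_doc_id_consistency above.
    return hashlib.sha256((path + content).encode("utf-8")).hexdigest()

def text_doc_id(content: str) -> str:
    # Matches TestTextLoader.test_basic_text: source and doc_id are the same hash.
    return hashlib.sha256(content.encode("utf-8")).hexdigest()

# Same text, different sources: the ids differ, as the integration test asserts.
assert file_doc_id("/tmp/a.txt", "Consistent content") != text_doc_id("Consistent content")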
137
tests/rag/test_webpage_loader.py
Normal file
@@ -0,0 +1,137 @@
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TestWebPageLoader:
    def setup_mock_response(self, text, status_code=200, content_type="text/html"):
        response = Mock()
        response.text = text
        response.apparent_encoding = "utf-8"
        response.status_code = status_code
        response.headers = {"content-type": content_type}
        return response

    def setup_mock_soup(self, text, title=None, script_style_elements=None):
        soup = Mock()
        soup.get_text.return_value = text
        soup.title = Mock(string=title) if title is not None else None
        soup.return_value = script_style_elements or []
        return soup

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_load_basic_webpage(self, mock_bs, mock_get):
        mock_get.return_value = self.setup_mock_response("<html><head><title>Test Page</title></head><body><p>Test content</p></body></html>")
        mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")

        loader = WebPageLoader()
        result = loader.load(SourceContent("https://example.com"))

        assert isinstance(result, LoaderResult)
        assert result.content == "Test content"
        assert result.metadata["title"] == "Test Page"

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
        html = """
        <html><head><title>Page with Scripts</title><style>body { color: red; }</style></head>
        <body><script>console.log('test');</script><p>Visible content</p></body></html>
        """
        mock_get.return_value = self.setup_mock_response(html)
        scripts = [Mock(), Mock()]
        styles = [Mock()]
        for el in scripts + styles:
            el.decompose = Mock()
        mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)

        loader = WebPageLoader()
        result = loader.load(SourceContent("https://example.com/with-scripts"))

        assert "Visible content" in result.content
        for el in scripts + styles:
            el.decompose.assert_called_once()

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
        mock_get.return_value = self.setup_mock_response("<html><body><p> Messy text </p></body></html>")
        mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None)

        loader = WebPageLoader()
        result = loader.load(SourceContent("https://example.com/messy-text"))
        assert result.content is not None
        assert result.metadata["title"] == ""

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_empty_or_missing_title(self, mock_bs, mock_get):
        for title in [None, ""]:
            mock_get.return_value = self.setup_mock_response("<html><head><title></title></head><body>Content</body></html>")
            mock_bs.return_value = self.setup_mock_soup("Content", title=title)

            loader = WebPageLoader()
            result = loader.load(SourceContent("https://example.com"))
            assert result.metadata["title"] == ""

    @patch('requests.get')
    def test_custom_and_default_headers(self, mock_get):
        mock_get.return_value = self.setup_mock_response("<html><body>Test</body></html>")
        custom_headers = {"User-Agent": "Bot", "Authorization": "Bearer xyz", "Accept": "text/html"}

        with patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') as mock_bs:
            mock_bs.return_value = self.setup_mock_soup("Test")
            WebPageLoader().load(SourceContent("https://example.com"), headers=custom_headers)

        assert mock_get.call_args[1]['headers'] == custom_headers

    @patch('requests.get')
    def test_error_handling(self, mock_get):
        for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
            mock_get.side_effect = error
            with pytest.raises(ValueError, match="Error loading webpage"):
                WebPageLoader().load(SourceContent("https://example.com"))

    @patch('requests.get')
    def test_timeout_and_http_error(self, mock_get):
        import requests
        mock_get.side_effect = requests.Timeout("Timeout")
        with pytest.raises(ValueError):
            WebPageLoader().load(SourceContent("https://example.com"))

        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.HTTPError("404")
        mock_get.side_effect = None
        mock_get.return_value = mock_response
        with pytest.raises(ValueError):
            WebPageLoader().load(SourceContent("https://example.com/404"))

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_doc_id_consistency(self, mock_bs, mock_get):
        mock_get.return_value = self.setup_mock_response("<html><body>Doc</body></html>")
        mock_bs.return_value = self.setup_mock_soup("Doc")

        loader = WebPageLoader()
        result1 = loader.load(SourceContent("https://example.com"))
        result2 = loader.load(SourceContent("https://example.com"))

        assert result1.doc_id == result2.doc_id

    @patch('requests.get')
    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
    def test_status_code_and_content_type(self, mock_bs, mock_get):
        for status in [200, 201, 301]:
            mock_get.return_value = self.setup_mock_response(f"<html><body>Status {status}</body></html>", status_code=status)
            mock_bs.return_value = self.setup_mock_soup(f"Status {status}")
            result = WebPageLoader().load(SourceContent(f"https://example.com/{status}"))
            assert result.metadata["status_code"] == status

        for ctype in ["text/html", "text/plain", "application/xhtml+xml"]:
            mock_get.return_value = self.setup_mock_response("<html><body>Content</body></html>", content_type=ctype)
            mock_bs.return_value = self.setup_mock_soup("Content")
            result = WebPageLoader().load(SourceContent("https://example.com"))
            assert result.metadata["content_type"] == ctype
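Because these tests mock both requests and BeautifulSoup, the assertions outline the flow rather than the markup handling: fetch with caller-supplied headers, decompose script/style nodes, take soup.get_text(), and record title, status_code and content_type in the metadata, mapping any fetch failure to ValueError. A rough sketch under those assumptions; fetch_page is an illustrative name, the default headers are illustrative too, and requests plus beautifulsoup4 are assumed available.

import requests
from bs4 import BeautifulSoup

def fetch_page(url, headers=None):
    # Default headers here are illustrative, not the loader's actual defaults.
    headers = headers or {"Accept": "text/html"}
    try:
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
    except Exception as exc:
        # Timeouts, HTTP errors and anything else surface as ValueError, as the tests expect.
        raise ValueError(f"Error loading webpage: {exc}") from exc

    soup = BeautifulSoup(response.text, "html.parser")
    for element in soup(["script", "style"]):
        element.decompose()
    text = " ".join(soup.get_text().split())  # squeeze runs of whitespace

    metadata = {
        "title": soup.title.string if soup.title and soup.title.string else "",
        "status_code": response.status_code,
        "content_type": response.headers.get("content-type", ""),
    }
    return text, metadata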
137
tests/rag/test_xml_loader.py
Normal file
@@ -0,0 +1,137 @@