Adds RAG feature (#406)

* feat: initialize rag

* refactor: use cosine distance metric for ChromaDB

* feat: use RecursiveCharacterTextSplitter as the chunking strategy

* feat: support a chunker and loader per data_type

* feat: add JSON loader

* feat: add CSVLoader

* feat: add loader for DOCX files

* feat: add loader for MDX files

* feat: add loader for XML files

* feat: add webpage loader

* feat: support loading files from an entire directory

* feat: auto-load the appropriate loader for additional DataTypes

* feat: add chunkers for specific data types

- Each chunker uses separators specific to its content type

* feat: prevent document duplication and centralize content management

- Implement document deduplication logic in RAG
  * Check for existing documents by source reference
  * Compare doc IDs to detect content changes
  * Automatically replace outdated content while preventing duplicates

- Centralize common functionality for better maintainability
  * Create SourceContent class to handle URLs, files, and text uniformly
  * Extract shared utilities (compute_sha256) to misc.py
  * Standardize doc ID generation across all loaders

- Improve RAG system architecture
  * All loaders now inherit consistent patterns from centralized BaseLoader
  * Better separation of concerns with dedicated content management classes
  * Standardized LoaderResult structure across all loader implementations

* chore: split text loaders file

* test: add missing tests for RAG loaders

* refactor: quality-of-life improvements

* fix: add missing uv install syntax in DOCXLoader
Lucas Gomide authored on 2025-08-19 19:30:35 -03:00 (committed by GitHub)
parent 1ce016df8b, commit dc039cfac8
31 changed files with 2595 additions and 0 deletions
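For orientation, here is a minimal, hedged usage sketch of the API this PR adds, based on the RAG class in the diff below; the file paths and the question are made up, and a configured embedding provider (litellm credentials) is assumed:

from crewai_tools.rag.core import RAG

# In-memory ChromaDB collection using cosine distance; pass persist_directory to keep it on disk.
rag = RAG(collection_name="crewai_knowledge_base", embedding_model="text-embedding-3-small")

# Loader and chunker are resolved automatically from the detected DataType.
rag.add("docs/handbook.docx")               # illustrative local file
rag.add("https://example.com/blog/post")    # handled by the webpage loader and chunker

print(rag.query("What does the handbook say about onboarding?"))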


@@ -0,0 +1,41 @@
from typing import Any, Optional

from crewai_tools.rag.core import RAG
from crewai_tools.tools.rag.rag_tool import Adapter


class RAGAdapter(Adapter):
    def __init__(
        self,
        collection_name: str = "crewai_knowledge_base",
        persist_directory: Optional[str] = None,
        embedding_model: str = "text-embedding-3-small",
        top_k: int = 5,
        embedding_api_key: Optional[str] = None,
        **embedding_kwargs
    ):
        super().__init__()
        # Prepare embedding configuration
        embedding_config = {
            "api_key": embedding_api_key,
            **embedding_kwargs
        }
        self._adapter = RAG(
            collection_name=collection_name,
            persist_directory=persist_directory,
            embedding_model=embedding_model,
            top_k=top_k,
            embedding_config=embedding_config
        )

    def query(self, question: str) -> str:
        return self._adapter.query(question)

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._adapter.add(*args, **kwargs)
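A sketch of how the adapter above could be plugged into the existing RagTool; this assumes RagTool accepts a pre-built adapter instance and that its add() delegates to the adapter (the import path for RAGAdapter is not shown in this diff, so the class is referenced as defined above):

from crewai_tools import RagTool

adapter = RAGAdapter(collection_name="project_docs", top_k=3)
tool = RagTool(adapter=adapter)     # hand the tool to an agent as usual
tool.add("docs/architecture.md")    # illustrative path; assumed to delegate to adapter.add()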


@@ -0,0 +1,8 @@
from crewai_tools.rag.core import RAG, EmbeddingService
from crewai_tools.rag.data_types import DataType

__all__ = [
    "RAG",
    "EmbeddingService",
    "DataType",
]


@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent


class LoaderResult(BaseModel):
    content: str = Field(description="The text content of the source")
    source: str = Field(description="The source of the content", default="unknown")
    metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict)
    doc_id: str = Field(description="The id of the document")


class BaseLoader(ABC):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}

    @abstractmethod
    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        ...

    def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str:
        """
        Generate a unique document id from the source reference and content.

        The id is the SHA-256 hash of the source reference concatenated with the
        content; either argument may be omitted and is treated as an empty string.
        Both are optional because the TEXT data type has no source reference, in
        which case the content alone determines the id.
        """
        source_ref = source_ref or ""
        content = content or ""
        return compute_sha256(source_ref + content)
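A small illustration of the id scheme; the subclass and values here are purely for demonstration:

from crewai_tools.rag.misc import compute_sha256

class _EchoLoader(BaseLoader):  # hypothetical minimal subclass, only to make BaseLoader instantiable
    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        text = content.source
        return LoaderResult(content=text, source=content.source_ref,
                            doc_id=self.generate_doc_id(content=text))

doc_id = _EchoLoader().generate_doc_id(source_ref="docs/a.txt", content="hello")
assert doc_id == compute_sha256("docs/a.txt" + "hello")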


@@ -0,0 +1,15 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker
from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker

__all__ = [
    "BaseChunker",
    "DefaultChunker",
    "TextChunker",
    "DocxChunker",
    "MdxChunker",
    "CsvChunker",
    "JsonChunker",
    "XmlChunker",
]


@@ -0,0 +1,167 @@
from typing import List, Optional
import re
class RecursiveCharacterTextSplitter:
"""
A text splitter that recursively splits text based on a hierarchy of separators.
"""
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
):
"""
Initialize the RecursiveCharacterTextSplitter.
Args:
chunk_size: Maximum size of each chunk
chunk_overlap: Number of characters to overlap between chunks
separators: List of separators to use for splitting (in order of preference)
keep_separator: Whether to keep the separator in the split text
"""
if chunk_overlap >= chunk_size:
raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})")
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._keep_separator = keep_separator
self._separators = separators or [
"\n\n",
"\n",
" ",
"",
]
def split_text(self, text: str) -> List[str]:
return self._split_text(text, self._separators)
def _split_text(self, text: str, separators: List[str]) -> List[str]:
separator = separators[-1]
new_separators = []
for i, sep in enumerate(separators):
if sep == "":
separator = sep
break
if re.search(re.escape(sep), text):
separator = sep
new_separators = separators[i + 1:]
break
splits = self._split_text_with_separator(text, separator)
good_splits = []
for split in splits:
if len(split) < self._chunk_size:
good_splits.append(split)
else:
if new_separators:
other_info = self._split_text(split, new_separators)
good_splits.extend(other_info)
else:
good_splits.extend(self._split_by_characters(split))
return self._merge_splits(good_splits, separator)
def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
if separator == "":
return list(text)
if self._keep_separator and separator in text:
parts = text.split(separator)
splits = []
for i, part in enumerate(parts):
if i == 0:
splits.append(part)
elif i == len(parts) - 1:
if part:
splits.append(separator + part)
else:
if part:
splits.append(separator + part)
else:
if splits:
splits[-1] += separator
return [s for s in splits if s]
else:
return text.split(separator)
def _split_by_characters(self, text: str) -> List[str]:
chunks = []
for i in range(0, len(text), self._chunk_size):
chunks.append(text[i:i + self._chunk_size])
return chunks
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
"""Merge splits into chunks with proper overlap."""
docs = []
current_doc = []
total = 0
for split in splits:
split_len = len(split)
if total + split_len > self._chunk_size and current_doc:
if separator == "":
doc = "".join(current_doc)
else:
doc = separator.join(current_doc)
if doc:
docs.append(doc)
# Handle overlap by keeping some of the previous content
while total > self._chunk_overlap and len(current_doc) > 1:
removed = current_doc.pop(0)
total -= len(removed)
if separator != "":
total -= len(separator)
current_doc.append(split)
total += split_len
if separator != "" and len(current_doc) > 1:
total += len(separator)
if current_doc:
if separator == "":
doc = "".join(current_doc)
else:
doc = separator.join(current_doc)
if doc:
docs.append(doc)
return docs
class BaseChunker:
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
"""
Initialize the Chunker
Args:
chunk_size: Maximum size of each chunk
chunk_overlap: Number of characters to overlap between chunks
separators: List of separators to use for splitting
keep_separator: Whether to keep separators in the chunks
"""
self._splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=separators,
keep_separator=keep_separator,
)
def chunk(self, text: str) -> List[str]:
if not text or not text.strip():
return []
return self._splitter.split_text(text)
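A quick illustration of the splitter's behaviour; the text and sizes are arbitrary example values:

splitter = RecursiveCharacterTextSplitter(chunk_size=40, chunk_overlap=10)
chunks = splitter.split_text("First paragraph.\n\nSecond paragraph.\n\nThird one.")
# The text is split on "\n\n" first, then the pieces are merged back into chunks
# of roughly chunk_size characters, carrying trailing content forward as overlap
# between consecutive chunks.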


@@ -0,0 +1,6 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class DefaultChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True):
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -0,0 +1,49 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional
class CsvChunker(BaseChunker):
def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\nRow ", # Row boundaries (from CSVLoader format)
"\n", # Line breaks
" | ", # Column separators
", ", # Comma separators
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
class JsonChunker(BaseChunker):
def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n\n", # Object/array boundaries
"\n", # Line breaks
"},", # Object endings
"],", # Array endings
", ", # Property separators
": ", # Key-value separators
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
class XmlChunker(BaseChunker):
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n\n", # Element boundaries
"\n", # Line breaks
">", # Tag endings
". ", # Sentence endings (for text content)
"! ", # Exclamation endings
"? ", # Question endings
", ", # Comma separators
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
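These per-type defaults can be overridden by passing a chunker explicitly to RAG.add (defined in core.py later in this diff); in this hedged sketch, rag is assumed to be an existing RAG instance and the CSV path is illustrative:

from crewai_tools.rag.chunkers.structured_chunker import CsvChunker

rag.add("data/products.csv", chunker=CsvChunker(chunk_size=800, chunk_overlap=50))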


@@ -0,0 +1,59 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional
class TextChunker(BaseChunker):
def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n\n\n", # Multiple line breaks (sections)
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
class DocxChunker(BaseChunker):
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n\n\n", # Multiple line breaks (major sections)
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
class MdxChunker(BaseChunker):
def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n## ", # H2 headers (major sections)
"\n### ", # H3 headers (subsections)
"\n#### ", # H4 headers (sub-subsections)
"\n\n", # Paragraph breaks
"\n```", # Code block boundaries
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -0,0 +1,20 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional
class WebsiteChunker(BaseChunker):
def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
if separators is None:
separators = [
"\n\n\n", # Major section breaks
"\n\n", # Paragraph breaks
"\n", # Line breaks
". ", # Sentence endings
"! ", # Exclamation endings
"? ", # Question endings
"; ", # Semicolon breaks
", ", # Comma breaks
" ", # Word breaks
"", # Character level
]
super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


@@ -0,0 +1,232 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.misc import compute_sha256
logger = logging.getLogger(__name__)
class EmbeddingService:
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
self.model = model
self.kwargs = kwargs
def embed_text(self, text: str) -> List[float]:
try:
response = litellm.embedding(
model=self.model,
input=[text],
**self.kwargs
)
return response.data[0]['embedding']
except Exception as e:
logger.error(f"Error generating embedding: {e}")
raise
def embed_batch(self, texts: List[str]) -> List[List[float]]:
if not texts:
return []
try:
response = litellm.embedding(
model=self.model,
input=texts,
**self.kwargs
)
return [data['embedding'] for data in response.data]
except Exception as e:
logger.error(f"Error generating batch embeddings: {e}")
raise
class Document(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
content: str
metadata: Dict[str, Any] = Field(default_factory=dict)
data_type: DataType = DataType.TEXT
source: Optional[str] = None
class RAG(Adapter):
collection_name: str = "crewai_knowledge_base"
persist_directory: Optional[str] = None
embedding_model: str = "text-embedding-3-large"
summarize: bool = False
top_k: int = 5
embedding_config: Dict[str, Any] = Field(default_factory=dict)
_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_embedding_service: EmbeddingService = PrivateAttr()
def model_post_init(self, __context: Any) -> None:
try:
if self.persist_directory:
self._client = chromadb.PersistentClient(path=self.persist_directory)
else:
self._client = chromadb.Client()
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"}
)
self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
except Exception as e:
logger.error(f"Failed to initialize ChromaDB: {e}")
raise
super().model_post_init(__context)
def add(
self,
content: str | Path,
data_type: Optional[Union[str, DataType]] = None,
metadata: Optional[Dict[str, Any]] = None,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
**kwargs: Any
) -> None:
source_content = SourceContent(content)
data_type = self._get_data_type(data_type=data_type, content=source_content)
if not loader:
loader = data_type.get_loader()
if not chunker:
chunker = data_type.get_chunker()
loader_result = loader.load(source_content)
doc_id = loader_result.doc_id
existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
existing_doc_id = existing_doc and existing_doc['metadatas'][0]['doc_id'] if existing_doc['metadatas'] else None
if existing_doc_id == doc_id:
logger.warning(f"Document with source {loader_result.source} already exists")
return
# A document with the same source ref exists but its content has changed; delete the old reference
if existing_doc_id and existing_doc_id != loader_result.doc_id:
logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
self._collection.delete(where={"doc_id": existing_doc_id})
documents = []
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
doc_metadata = (metadata or {}).copy()
doc_metadata['chunk_index'] = i
documents.append(Document(
id=compute_sha256(chunk),
content=chunk,
metadata=doc_metadata,
data_type=data_type,
source=loader_result.source
))
if not documents:
logger.warning("No documents to add")
return
contents = [doc.content for doc in documents]
try:
embeddings = self._embedding_service.embed_batch(contents)
except Exception as e:
logger.error(f"Failed to generate embeddings: {e}")
return
ids = [doc.id for doc in documents]
metadatas = []
for doc in documents:
doc_metadata = doc.metadata.copy()
doc_metadata.update({
"data_type": doc.data_type.value,
"source": doc.source,
"doc_id": doc_id
})
metadatas.append(doc_metadata)
try:
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=contents,
metadatas=metadatas,
)
logger.info(f"Added {len(documents)} documents to knowledge base")
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
try:
question_embedding = self._embedding_service.embed_text(question)
results = self._collection.query(
query_embeddings=[question_embedding],
n_results=self.top_k,
where=where,
include=["documents", "metadatas", "distances"]
)
if not results or not results.get("documents") or not results["documents"][0]:
return "No relevant content found."
documents = results["documents"][0]
metadatas = results.get("metadatas", [None])[0] or []
distances = results.get("distances", [None])[0] or []
# Return sources with relevance scores
formatted_results = []
for i, doc in enumerate(documents):
metadata = metadatas[i] if i < len(metadatas) else {}
distance = distances[i] if i < len(distances) else 1.0
source = metadata.get("source", "unknown") if metadata else "unknown"
score = 1 - distance if distance is not None else 0 # Convert distance to similarity
formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")
return "\n\n".join(formatted_results)
except Exception as e:
logger.error(f"Query failed: {e}")
return f"Error querying knowledge base: {e}"
def delete_collection(self) -> None:
try:
self._client.delete_collection(self.collection_name)
logger.info(f"Deleted collection: {self.collection_name}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
def get_collection_info(self) -> Dict[str, Any]:
try:
count = self._collection.count()
return {
"name": self.collection_name,
"count": count,
"embedding_model": self.embedding_model
}
except Exception as e:
logger.error(f"Failed to get collection info: {e}")
return {"error": str(e)}
def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
try:
if isinstance(data_type, str):
return DataType(data_type)
except Exception as e:
pass
return content.data_type
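To make the deduplication behaviour concrete, a brief hedged example; rag is assumed to be an existing RAG instance, and the file name and metadata are illustrative:

rag.add("handbook.docx", metadata={"project": "alpha"})
rag.add("handbook.docx")   # same source and unchanged content: logged as existing and skipped

# Optional ChromaDB metadata filter at query time:
print(rag.query("What is the vacation policy?", where={"project": "alpha"}))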


@@ -0,0 +1,137 @@
from enum import Enum
from pathlib import Path
from urllib.parse import urlparse
import os
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.base_loader import BaseLoader
class DataType(str, Enum):
PDF_FILE = "pdf_file"
TEXT_FILE = "text_file"
CSV = "csv"
JSON = "json"
XML = "xml"
DOCX = "docx"
MDX = "mdx"
# Database types
MYSQL = "mysql"
POSTGRES = "postgres"
# Repository types
GITHUB = "github"
DIRECTORY = "directory"
# Web types
WEBSITE = "website"
DOCS_SITE = "docs_site"
# Raw types
TEXT = "text"
def get_chunker(self) -> BaseChunker:
from importlib import import_module
chunkers = {
DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
DataType.TEXT: ("text_chunker", "TextChunker"),
DataType.DOCX: ("text_chunker", "DocxChunker"),
DataType.MDX: ("text_chunker", "MdxChunker"),
# Structured formats
DataType.CSV: ("structured_chunker", "CsvChunker"),
DataType.JSON: ("structured_chunker", "JsonChunker"),
DataType.XML: ("structured_chunker", "XmlChunker"),
DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
}
module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
module_path = f"crewai_tools.rag.chunkers.{module_name}"
try:
module = import_module(module_path)
return getattr(module, class_name)()
except Exception as e:
raise ValueError(f"Error loading chunker for {self}: {e}")
def get_loader(self) -> BaseLoader:
from importlib import import_module
loaders = {
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
DataType.TEXT: ("text_loader", "TextLoader"),
DataType.XML: ("xml_loader", "XMLLoader"),
DataType.WEBSITE: ("webpage_loader", "WebPageLoader"),
DataType.MDX: ("mdx_loader", "MDXLoader"),
DataType.JSON: ("json_loader", "JSONLoader"),
DataType.DOCX: ("docx_loader", "DOCXLoader"),
DataType.CSV: ("csv_loader", "CSVLoader"),
DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
}
module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
module_path = f"crewai_tools.rag.loaders.{module_name}"
try:
module = import_module(module_path)
return getattr(module, class_name)()
except Exception as e:
raise ValueError(f"Error loading loader for {self}: {e}")
class DataTypes:
@staticmethod
def from_content(content: str | Path | None = None) -> DataType:
if content is None:
return DataType.TEXT
if isinstance(content, Path):
content = str(content)
is_url = False
if isinstance(content, str):
try:
url = urlparse(content)
is_url = (url.scheme and url.netloc) or url.scheme == "file"
except Exception:
pass
def get_file_type(path: str) -> DataType | None:
mapping = {
".pdf": DataType.PDF_FILE,
".csv": DataType.CSV,
".mdx": DataType.MDX,
".md": DataType.MDX,
".docx": DataType.DOCX,
".json": DataType.JSON,
".xml": DataType.XML,
".txt": DataType.TEXT_FILE,
}
for ext, dtype in mapping.items():
if path.endswith(ext):
return dtype
return None
if is_url:
dtype = get_file_type(url.path)
if dtype:
return dtype
if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
return DataType.DOCS_SITE
if "github.com" in url.netloc:
return DataType.GITHUB
return DataType.WEBSITE
if os.path.isfile(content):
dtype = get_file_type(content)
if dtype:
return dtype
if os.path.exists(content):
return DataType.TEXT_FILE
elif os.path.isdir(content):
return DataType.DIRECTORY
return DataType.TEXT
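For reference, a few illustrative detection results (the URLs and strings are made up; local-file detection additionally requires that the path exists on disk):

from crewai_tools.rag.data_types import DataTypes, DataType

DataTypes.from_content("https://docs.example.com/guide")   # DataType.DOCS_SITE ("docs" in the host)
DataTypes.from_content("https://github.com/org/repo")      # DataType.GITHUB
DataTypes.from_content("https://example.com/report.xml")   # DataType.XML (extension on the URL path)
DataTypes.from_content("plain inline text")                # DataType.TEXT (not a URL or an existing path)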


@@ -0,0 +1,20 @@
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader

__all__ = [
    "TextFileLoader",
    "TextLoader",
    "XMLLoader",
    "WebPageLoader",
    "MDXLoader",
    "JSONLoader",
    "DOCXLoader",
    "CSVLoader",
    "DirectoryLoader",
]


@@ -0,0 +1,72 @@
import csv
from io import StringIO
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class CSVLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
source_ref = source_content.source_ref
content_str = source_content.source
if source_content.is_url():
content_str = self._load_from_url(content_str, kwargs)
elif source_content.path_exists():
content_str = self._load_from_file(content_str)
return self._parse_csv(content_str, source_ref)
def _load_from_url(self, url: str, kwargs: dict) -> str:
import requests
headers = kwargs.get("headers", {
"Accept": "text/csv, application/csv, text/plain",
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)"
})
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.text
except Exception as e:
raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}")
def _load_from_file(self, path: str) -> str:
with open(path, "r", encoding="utf-8") as file:
return file.read()
def _parse_csv(self, content: str, source_ref: str) -> LoaderResult:
try:
csv_reader = csv.DictReader(StringIO(content))
text_parts = []
headers = csv_reader.fieldnames
if headers:
text_parts.append("Headers: " + " | ".join(headers))
text_parts.append("-" * 50)
for row_num, row in enumerate(csv_reader, 1):
row_text = " | ".join([f"{k}: {v}" for k, v in row.items() if v])
text_parts.append(f"Row {row_num}: {row_text}")
text = "\n".join(text_parts)
metadata = {
"format": "csv",
"columns": headers,
"rows": len(text_parts) - 2 if headers else 0
}
except Exception as e:
text = content
metadata = {"format": "csv", "parse_error": str(e)}
return LoaderResult(
content=text,
source=source_ref,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
)


@@ -0,0 +1,142 @@
import os
from pathlib import Path
from typing import List
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class DirectoryLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
"""
Load and process all files from a directory recursively.
Args:
source: Directory path or URL to a directory listing
**kwargs: Additional options:
- recursive: bool (default True) - Whether to search recursively
- include_extensions: list - Only include files with these extensions
- exclude_extensions: list - Exclude files with these extensions
- max_files: int - Maximum number of files to process
"""
source_ref = source_content.source_ref
if source_content.is_url():
raise ValueError("URL directory loading is not supported. Please provide a local directory path.")
if not os.path.exists(source_ref):
raise FileNotFoundError(f"Directory does not exist: {source_ref}")
if not os.path.isdir(source_ref):
raise ValueError(f"Path is not a directory: {source_ref}")
return self._process_directory(source_ref, kwargs)
def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
recursive = kwargs.get("recursive", True)
include_extensions = kwargs.get("include_extensions", None)
exclude_extensions = kwargs.get("exclude_extensions", None)
max_files = kwargs.get("max_files", None)
files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions)
if max_files and len(files) > max_files:
files = files[:max_files]
all_contents = []
processed_files = []
errors = []
for file_path in files:
try:
result = self._process_single_file(file_path)
if result:
all_contents.append(f"=== File: {file_path} ===\n{result.content}")
processed_files.append({
"path": file_path,
"metadata": result.metadata,
"source": result.source
})
except Exception as e:
error_msg = f"Error processing {file_path}: {str(e)}"
errors.append(error_msg)
all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")
combined_content = "\n\n".join(all_contents)
metadata = {
"format": "directory",
"directory_path": dir_path,
"total_files": len(files),
"processed_files": len(processed_files),
"errors": len(errors),
"file_details": processed_files,
"error_details": errors
}
return LoaderResult(
content=combined_content,
source=dir_path,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content)
)
def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]:
"""Find all files in directory matching criteria."""
files = []
if recursive:
for root, dirs, filenames in os.walk(dir_path):
dirs[:] = [d for d in dirs if not d.startswith('.')]
for filename in filenames:
if self._should_include_file(filename, include_ext, exclude_ext):
files.append(os.path.join(root, filename))
else:
try:
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext):
files.append(item_path)
except PermissionError:
pass
return sorted(files)
def _should_include_file(self, filename: str, include_ext: List[str] = None, exclude_ext: List[str] = None) -> bool:
"""Determine if a file should be included based on criteria."""
if filename.startswith('.'):
return False
_, ext = os.path.splitext(filename.lower())
if include_ext:
if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]:
return False
if exclude_ext:
if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]:
return False
return True
def _process_single_file(self, file_path: str) -> LoaderResult:
from crewai_tools.rag.data_types import DataTypes
data_type = DataTypes.from_content(Path(file_path))
loader = data_type.get_loader()
result = loader.load(SourceContent(file_path))
if result.metadata is None:
result.metadata = {}
result.metadata.update({
"file_path": file_path,
"file_size": os.path.getsize(file_path),
"data_type": str(data_type),
"loader_type": loader.__class__.__name__
})
return result
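A short sketch of calling the loader directly with its filtering options; the directory path is illustrative:

loader = DirectoryLoader()
result = loader.load(
    SourceContent("./docs"),               # illustrative local directory
    recursive=True,
    include_extensions=[".md", ".txt"],
    max_files=50,
)
print(result.metadata["processed_files"], "files combined into one LoaderResult")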


@@ -0,0 +1,72 @@
import os
import tempfile
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class DOCXLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
try:
from docx import Document as DocxDocument
except ImportError:
raise ImportError("python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]")
source_ref = source_content.source_ref
if source_content.is_url():
temp_file = self._download_from_url(source_ref, kwargs)
try:
return self._load_from_file(temp_file, source_ref, DocxDocument)
finally:
os.unlink(temp_file)
elif source_content.path_exists():
return self._load_from_file(source_ref, source_ref, DocxDocument)
else:
raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}")
def _download_from_url(self, url: str, kwargs: dict) -> str:
import requests
headers = kwargs.get("headers", {
"Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)"
})
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
# Create temporary file to save the DOCX content
with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file:
temp_file.write(response.content)
return temp_file.name
except Exception as e:
raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}")
def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult:
try:
doc = DocxDocument(file_path)
text_parts = []
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text_parts.append(paragraph.text)
content = "\n".join(text_parts)
metadata = {
"format": "docx",
"paragraphs": len(doc.paragraphs),
"tables": len(doc.tables)
}
return LoaderResult(
content=content,
source=source_ref,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
)
except Exception as e:
raise ValueError(f"Error loading DOCX file: {str(e)}")


@@ -0,0 +1,69 @@
import json
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
class JSONLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
source_ref = source_content.source_ref
content = source_content.source
if source_content.is_url():
content = self._load_from_url(source_ref, kwargs)
elif source_content.path_exists():
content = self._load_from_file(source_ref)
return self._parse_json(content, source_ref)
def _load_from_url(self, url: str, kwargs: dict) -> str:
import requests
headers = kwargs.get("headers", {
"Accept": "application/json",
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)"
})
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.text if not self._is_json_response(response) else json.dumps(response.json(), indent=2)
except Exception as e:
raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}")
def _is_json_response(self, response) -> bool:
try:
response.json()
return True
except ValueError:
return False
def _load_from_file(self, path: str) -> str:
with open(path, "r", encoding="utf-8") as file:
return file.read()
def _parse_json(self, content: str, source_ref: str) -> LoaderResult:
try:
data = json.loads(content)
if isinstance(data, dict):
text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
elif isinstance(data, list):
text = "\n".join(json.dumps(item, indent=0) for item in data)
else:
text = json.dumps(data, indent=0)
metadata = {
"format": "json",
"type": type(data).__name__,
"size": len(data) if isinstance(data, (list, dict)) else 1
}
except json.JSONDecodeError as e:
text = content
metadata = {"format": "json", "parse_error": str(e)}
return LoaderResult(
content=text,
source=source_ref,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
)


@@ -0,0 +1,59 @@
import re
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class MDXLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
source_ref = source_content.source_ref
content = source_content.source
if source_content.is_url():
content = self._load_from_url(source_ref, kwargs)
elif source_content.path_exists():
content = self._load_from_file(source_ref)
return self._parse_mdx(content, source_ref)
def _load_from_url(self, url: str, kwargs: dict) -> str:
import requests
headers = kwargs.get("headers", {
"Accept": "text/markdown, text/x-markdown, text/plain",
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)"
})
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.text
except Exception as e:
raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}")
def _load_from_file(self, path: str) -> str:
with open(path, "r", encoding="utf-8") as file:
return file.read()
def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
cleaned_content = content
# Remove import statements
cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE)
# Remove export statements
cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE)
# Remove JSX tags (simple approach)
cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content)
# Clean up extra whitespace
cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content)
cleaned_content = cleaned_content.strip()
metadata = {"format": "mdx"}
return LoaderResult(
content=cleaned_content,
source=source_ref,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content)
)


@@ -0,0 +1,28 @@
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class TextFileLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
        if not source_content.path_exists():
            raise FileNotFoundError(f"The following file does not exist: {source_content.source}")

        with open(source_content.source, "r", encoding="utf-8") as file:
            content = file.read()

        return LoaderResult(
            content=content,
            source=source_ref,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
        )


class TextLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        return LoaderResult(
            content=source_content.source,
            source=source_content.source_ref,
            doc_id=self.generate_doc_id(content=source_content.source)
        )


@@ -0,0 +1,47 @@
import re
import requests
from bs4 import BeautifulSoup
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class WebPageLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
url = source_content.source
headers = kwargs.get("headers", {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"Accept-Language": "en-US,en;q=0.9",
})
try:
response = requests.get(url, timeout=15, headers=headers)
response.encoding = response.apparent_encoding
soup = BeautifulSoup(response.text, "html.parser")
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text(" ")
text = re.sub("[ \t]+", " ", text)
text = re.sub("\\s+\n\\s+", "\n", text)
text = text.strip()
title = soup.title.string.strip() if soup.title and soup.title.string else ""
metadata = {
"url": url,
"title": title,
"status_code": response.status_code,
"content_type": response.headers.get("content-type", "")
}
return LoaderResult(
content=text,
source=url,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=url, content=text)
)
except Exception as e:
raise ValueError(f"Error loading webpage {url}: {str(e)}")


@@ -0,0 +1,61 @@
import os
import xml.etree.ElementTree as ET
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class XMLLoader(BaseLoader):
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
source_ref = source_content.source_ref
content = source_content.source
if source_content.is_url():
content = self._load_from_url(source_ref, kwargs)
elif os.path.exists(source_ref):
content = self._load_from_file(source_ref)
return self._parse_xml(content, source_ref)
def _load_from_url(self, url: str, kwargs: dict) -> str:
import requests
headers = kwargs.get("headers", {
"Accept": "application/xml, text/xml, text/plain",
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)"
})
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.text
except Exception as e:
raise ValueError(f"Error fetching XML from URL {url}: {str(e)}")
def _load_from_file(self, path: str) -> str:
with open(path, "r", encoding="utf-8") as file:
return file.read()
def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
try:
if content.strip().startswith('<'):
root = ET.fromstring(content)
else:
root = ET.parse(source_ref).getroot()
text_parts = []
for text_content in root.itertext():
if text_content and text_content.strip():
text_parts.append(text_content.strip())
text = "\n".join(text_parts)
metadata = {"format": "xml", "root_tag": root.tag}
except ET.ParseError as e:
text = content
metadata = {"format": "xml", "parse_error": str(e)}
return LoaderResult(
content=text,
source=source_ref,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
)


@@ -0,0 +1,4 @@
import hashlib


def compute_sha256(content: str) -> str:
    return hashlib.sha256(content.encode("utf-8")).hexdigest()


@@ -0,0 +1,46 @@
import os
from urllib.parse import urlparse
from typing import TYPE_CHECKING
from pathlib import Path
from functools import cached_property

from crewai_tools.rag.misc import compute_sha256

if TYPE_CHECKING:
    from crewai_tools.rag.data_types import DataType


class SourceContent:
    def __init__(self, source: str | Path):
        self.source = str(source)

    def is_url(self) -> bool:
        if not isinstance(self.source, str):
            return False
        try:
            parsed_url = urlparse(self.source)
            return bool(parsed_url.scheme and parsed_url.netloc)
        except Exception:
            return False

    def path_exists(self) -> bool:
        return os.path.exists(self.source)

    @cached_property
    def data_type(self) -> "DataType":
        from crewai_tools.rag.data_types import DataTypes
        return DataTypes.from_content(self.source)

    @cached_property
    def source_ref(self) -> str:
        """
        Returns the source reference for the content.

        If the content is a URL or a local file, returns the source itself;
        otherwise, returns the SHA-256 hash of the content.
        """
        if self.is_url() or self.path_exists():
            return self.source
        return compute_sha256(self.source)
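A short illustration of how source_ref resolves; the values are arbitrary:

SourceContent("https://example.com/page").source_ref   # the URL itself
SourceContent("./README.md").source_ref                # the path, when it exists on disk
SourceContent("raw inline text").source_ref            # SHA-256 hash of the text
SourceContent("raw inline text").is_url()              # False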

tests/rag/__init__.py (new, empty file)

@@ -0,0 +1,130 @@
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
@pytest.fixture
def temp_csv_file():
created_files = []
def _create(content: str):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False)
f.write(content)
f.close()
created_files.append(f.name)
return f.name
yield _create
for path in created_files:
os.unlink(path)
class TestCSVLoader:
def test_load_csv_from_file(self, temp_csv_file):
path = temp_csv_file("name,age,city\nJohn,25,New York\nJane,30,Chicago")
loader = CSVLoader()
result = loader.load(SourceContent(path))
assert isinstance(result, LoaderResult)
assert "Headers: name | age | city" in result.content
assert "Row 1: name: John | age: 25 | city: New York" in result.content
assert "Row 2: name: Jane | age: 30 | city: Chicago" in result.content
assert result.metadata == {
"format": "csv",
"columns": ["name", "age", "city"],
"rows": 2,
}
assert result.source == path
assert result.doc_id
def test_load_csv_with_empty_values(self, temp_csv_file):
path = temp_csv_file("name,age,city\nJohn,,New York\n,30,")
result = CSVLoader().load(SourceContent(path))
assert "Row 1: name: John | city: New York" in result.content
assert "Row 2: age: 30" in result.content
assert result.metadata["rows"] == 2
def test_load_csv_malformed(self, temp_csv_file):
path = temp_csv_file("invalid,csv\nunclosed quote \"missing")
result = CSVLoader().load(SourceContent(path))
assert "Headers: invalid | csv" in result.content
assert 'Row 1: invalid: unclosed quote "missing' in result.content
assert result.metadata["columns"] == ["invalid", "csv"]
def test_load_csv_empty_file(self, temp_csv_file):
path = temp_csv_file("")
result = CSVLoader().load(SourceContent(path))
assert result.content == ""
assert result.metadata["rows"] == 0
def test_load_csv_text_input(self):
raw_csv = "col1,col2\nvalue1,value2\nvalue3,value4"
result = CSVLoader().load(SourceContent(raw_csv))
assert "Headers: col1 | col2" in result.content
assert "Row 1: col1: value1 | col2: value2" in result.content
assert "Row 2: col1: value3 | col2: value4" in result.content
assert result.metadata["columns"] == ["col1", "col2"]
assert result.metadata["rows"] == 2
def test_doc_id_is_deterministic(self, temp_csv_file):
path = temp_csv_file("name,value\ntest,123")
loader = CSVLoader()
result1 = loader.load(SourceContent(path))
result2 = loader.load(SourceContent(path))
assert result1.doc_id == result2.doc_id
@patch("requests.get")
def test_load_csv_from_url(self, mock_get):
mock_get.return_value = Mock(
text="name,value\ntest,123",
raise_for_status=Mock(return_value=None)
)
result = CSVLoader().load(SourceContent("https://example.com/data.csv"))
assert "Headers: name | value" in result.content
assert "Row 1: name: test | value: 123" in result.content
headers = mock_get.call_args[1]["headers"]
assert "text/csv" in headers["Accept"]
assert "crewai-tools CSVLoader" in headers["User-Agent"]
@patch("requests.get")
def test_load_csv_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(
text="data,value\ntest,456",
raise_for_status=Mock(return_value=None)
)
headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
result = CSVLoader().load(SourceContent("https://example.com/data.csv"), headers=headers)
assert "Headers: data | value" in result.content
assert mock_get.call_args[1]["headers"] == headers
@patch("requests.get")
def test_csv_loader_handles_network_errors(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = CSVLoader()
with pytest.raises(ValueError, match="Error fetching CSV from URL"):
loader.load(SourceContent("https://example.com/data.csv"))
@patch("requests.get")
def test_csv_loader_handles_http_error(self, mock_get):
mock_get.return_value = Mock()
mock_get.return_value.raise_for_status.side_effect = Exception("404 Not Found")
loader = CSVLoader()
with pytest.raises(ValueError, match="Error fetching CSV from URL"):
loader.load(SourceContent("https://example.com/notfound.csv"))


@@ -0,0 +1,149 @@
import os
import tempfile
import pytest
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
@pytest.fixture
def temp_directory():
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir
class TestDirectoryLoader:
def _create_file(self, directory, filename, content="test content"):
path = os.path.join(directory, filename)
with open(path, "w") as f:
f.write(content)
return path
def test_load_non_recursive(self, temp_directory):
self._create_file(temp_directory, "file1.txt")
self._create_file(temp_directory, "file2.txt")
subdir = os.path.join(temp_directory, "subdir")
os.makedirs(subdir)
self._create_file(subdir, "file3.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=False)
assert isinstance(result, LoaderResult)
assert "file1.txt" in result.content
assert "file2.txt" in result.content
assert "file3.txt" not in result.content
assert result.metadata["total_files"] == 2
def test_load_recursive(self, temp_directory):
self._create_file(temp_directory, "file1.txt")
nested = os.path.join(temp_directory, "subdir", "nested")
os.makedirs(nested)
self._create_file(os.path.join(temp_directory, "subdir"), "file2.txt")
self._create_file(nested, "file3.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=True)
assert all(f"file{i}.txt" in result.content for i in range(1, 4))
def test_include_and_exclude_extensions(self, temp_directory):
self._create_file(temp_directory, "a.txt")
self._create_file(temp_directory, "b.py")
self._create_file(temp_directory, "c.md")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), include_extensions=[".txt", ".py"])
assert "a.txt" in result.content
assert "b.py" in result.content
assert "c.md" not in result.content
result2 = loader.load(SourceContent(temp_directory), exclude_extensions=[".py", ".md"])
assert "a.txt" in result2.content
assert "b.py" not in result2.content
assert "c.md" not in result2.content
def test_max_files_limit(self, temp_directory):
for i in range(5):
self._create_file(temp_directory, f"file{i}.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), max_files=3)
assert result.metadata["total_files"] == 3
assert all(f"file{i}.txt" in result.content for i in range(3))
def test_hidden_files_and_dirs_excluded(self, temp_directory):
self._create_file(temp_directory, "visible.txt", "visible")
self._create_file(temp_directory, ".hidden.txt", "hidden")
hidden_dir = os.path.join(temp_directory, ".hidden")
os.makedirs(hidden_dir)
self._create_file(hidden_dir, "inside_hidden.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=True)
assert "visible.txt" in result.content
assert ".hidden.txt" not in result.content
assert "inside_hidden.txt" not in result.content
def test_directory_does_not_exist(self):
loader = DirectoryLoader()
with pytest.raises(FileNotFoundError, match="Directory does not exist"):
loader.load(SourceContent("/path/does/not/exist"))
def test_path_is_not_a_directory(self):
with tempfile.NamedTemporaryFile() as f:
loader = DirectoryLoader()
with pytest.raises(ValueError, match="Path is not a directory"):
loader.load(SourceContent(f.name))
def test_url_not_supported(self):
loader = DirectoryLoader()
with pytest.raises(ValueError, match="URL directory loading is not supported"):
loader.load(SourceContent("https://example.com"))
def test_processing_error_handling(self, temp_directory):
self._create_file(temp_directory, "valid.txt")
error_file = self._create_file(temp_directory, "error.txt")
loader = DirectoryLoader()
original_method = loader._process_single_file
def mock(file_path):
if "error" in file_path:
raise ValueError("Mock error")
return original_method(file_path)
loader._process_single_file = mock
result = loader.load(SourceContent(temp_directory))
assert "valid.txt" in result.content
assert "error.txt (ERROR)" in result.content
assert result.metadata["errors"] == 1
assert len(result.metadata["error_details"]) == 1
def test_metadata_structure(self, temp_directory):
self._create_file(temp_directory, "test.txt", "Sample")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory))
metadata = result.metadata
expected_keys = {
"format", "directory_path", "total_files", "processed_files",
"errors", "file_details", "error_details"
}
assert expected_keys.issubset(metadata)
assert all(k in metadata["file_details"][0] for k in ("path", "metadata", "source"))
def test_empty_directory(self, temp_directory):
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory))
assert result.content == ""
assert result.metadata["total_files"] == 0
assert result.metadata["processed_files"] == 0


@@ -0,0 +1,135 @@
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestDOCXLoader:
@patch('docx.Document')
def test_load_docx_from_file(self, mock_docx_class):
mock_doc = Mock()
mock_doc.paragraphs = [
Mock(text="First paragraph"),
Mock(text="Second paragraph"),
Mock(text=" ") # Blank paragraph
]
mock_doc.tables = []
mock_docx_class.return_value = mock_doc
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert isinstance(result, LoaderResult)
assert result.content == "First paragraph\nSecond paragraph"
assert result.metadata == {"format": "docx", "paragraphs": 3, "tables": 0}
assert result.source == f.name
@patch('docx.Document')
def test_load_docx_with_tables(self, mock_docx_class):
mock_doc = Mock()
mock_doc.paragraphs = [Mock(text="Document with table")]
mock_doc.tables = [Mock(), Mock()]
mock_docx_class.return_value = mock_doc
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert result.metadata["tables"] == 2
@patch('requests.get')
@patch('docx.Document')
@patch('tempfile.NamedTemporaryFile')
@patch('os.unlink')
def test_load_docx_from_url(self, mock_unlink, mock_tempfile, mock_docx_class, mock_get):
mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
mock_temp = Mock(name="/tmp/temp_docx_file.docx")
mock_temp.__enter__ = Mock(return_value=mock_temp)
mock_temp.__exit__ = Mock(return_value=None)
mock_tempfile.return_value = mock_temp
mock_doc = Mock()
mock_doc.paragraphs = [Mock(text="Content from URL")]
mock_doc.tables = []
mock_docx_class.return_value = mock_doc
loader = DOCXLoader()
result = loader.load(SourceContent("https://example.com/test.docx"))
assert "Content from URL" in result.content
assert result.source == "https://example.com/test.docx"
headers = mock_get.call_args[1]['headers']
assert "application/vnd.openxmlformats-officedocument.wordprocessingml.document" in headers['Accept']
assert "crewai-tools DOCXLoader" in headers['User-Agent']
mock_temp.write.assert_called_once_with(b"fake docx content")
@patch('requests.get')
@patch('docx.Document')
def test_load_docx_from_url_with_custom_headers(self, mock_docx_class, mock_get):
mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
loader = DOCXLoader()
custom_headers = {"Authorization": "Bearer token"}
with patch('tempfile.NamedTemporaryFile'), patch('os.unlink'):
loader.load(SourceContent("https://example.com/test.docx"), headers=custom_headers)
assert mock_get.call_args[1]['headers'] == custom_headers
@patch('requests.get')
def test_load_docx_url_download_error(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
loader.load(SourceContent("https://example.com/test.docx"))
@patch('requests.get')
def test_load_docx_url_http_error(self, mock_get):
mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404 Not Found")))
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
loader.load(SourceContent("https://example.com/notfound.docx"))
def test_load_docx_invalid_source(self):
loader = DOCXLoader()
with pytest.raises(ValueError, match="Source must be a valid file path or URL"):
loader.load(SourceContent("not_a_file_or_url"))
@patch('docx.Document')
def test_load_docx_parsing_error(self, mock_docx_class):
mock_docx_class.side_effect = Exception("Invalid DOCX file")
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error loading DOCX file"):
loader.load(SourceContent(f.name))
@patch('docx.Document')
def test_load_docx_empty_document(self, mock_docx_class):
mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert result.content == ""
assert result.metadata == {"paragraphs": 0, "tables": 0, "format": "docx"}
@patch('docx.Document')
def test_docx_doc_id_generation(self, mock_docx_class):
mock_docx_class.return_value = Mock(paragraphs=[Mock(text="Consistent content")], tables=[])
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
source = SourceContent(f.name)
assert loader.load(source).doc_id == loader.load(source).doc_id


@@ -0,0 +1,180 @@
import json
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestJSONLoader:
def _create_temp_json_file(self, data) -> str:
"""Helper to write JSON data to a temporary file and return its path."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(data, f)
return f.name
def _create_temp_raw_file(self, content: str) -> str:
"""Helper to write raw content to a temporary file and return its path."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write(content)
return f.name
def _load_from_path(self, path) -> LoaderResult:
loader = JSONLoader()
return loader.load(SourceContent(path))
def test_load_json_dict(self):
path = self._create_temp_json_file({"name": "John", "age": 30, "items": ["a", "b", "c"]})
try:
result = self._load_from_path(path)
assert isinstance(result, LoaderResult)
assert all(k in result.content for k in ["name", "John", "age", "30"])
assert result.metadata == {
"format": "json", "type": "dict", "size": 3
}
assert result.source == path
finally:
os.unlink(path)
def test_load_json_list(self):
path = self._create_temp_json_file([
{"id": 1, "name": "Item 1"},
{"id": 2, "name": "Item 2"},
])
try:
result = self._load_from_path(path)
assert result.metadata["type"] == "list"
assert result.metadata["size"] == 2
assert all(item in result.content for item in ["Item 1", "Item 2"])
finally:
os.unlink(path)
@pytest.mark.parametrize("value, expected_type", [
("simple string value", "str"),
(42, "int"),
])
def test_load_json_primitives(self, value, expected_type):
path = self._create_temp_json_file(value)
try:
result = self._load_from_path(path)
assert result.metadata["type"] == expected_type
assert result.metadata["size"] == 1
assert str(value) in result.content
finally:
os.unlink(path)
def test_load_malformed_json(self):
path = self._create_temp_raw_file('{"invalid": json,}')
try:
result = self._load_from_path(path)
assert result.metadata["format"] == "json"
assert "parse_error" in result.metadata
assert result.content == '{"invalid": json,}'
finally:
os.unlink(path)
def test_load_empty_file(self):
path = self._create_temp_raw_file('')
try:
result = self._load_from_path(path)
assert "parse_error" in result.metadata
assert result.content == ''
finally:
os.unlink(path)
def test_load_text_input(self):
json_text = '{"message": "hello", "count": 5}'
loader = JSONLoader()
result = loader.load(SourceContent(json_text))
assert all(part in result.content for part in ["message", "hello", "count", "5"])
assert result.metadata["type"] == "dict"
assert result.metadata["size"] == 2
def test_load_complex_nested_json(self):
data = {
"users": [
{"id": 1, "profile": {"name": "Alice", "settings": {"theme": "dark"}}},
{"id": 2, "profile": {"name": "Bob", "settings": {"theme": "light"}}}
],
"meta": {"total": 2, "version": "1.0"}
}
path = self._create_temp_json_file(data)
try:
result = self._load_from_path(path)
for value in ["Alice", "Bob", "dark", "light"]:
assert value in result.content
assert result.metadata["size"] == 2 # top-level keys
finally:
os.unlink(path)
def test_consistent_doc_id(self):
path = self._create_temp_json_file({"test": "data"})
try:
result1 = self._load_from_path(path)
result2 = self._load_from_path(path)
assert result1.doc_id == result2.doc_id
finally:
os.unlink(path)
# ------------------------------
# URL-based tests
# ------------------------------
@patch('requests.get')
def test_url_response_valid_json(self, mock_get):
mock_get.return_value = Mock(
text='{"key": "value", "number": 123}',
json=Mock(return_value={"key": "value", "number": 123}),
raise_for_status=Mock()
)
loader = JSONLoader()
result = loader.load(SourceContent("https://api.example.com/data.json"))
assert all(val in result.content for val in ["key", "value", "number", "123"])
headers = mock_get.call_args[1]['headers']
assert "application/json" in headers['Accept']
assert "crewai-tools JSONLoader" in headers['User-Agent']
@patch('requests.get')
def test_url_response_not_json(self, mock_get):
mock_get.return_value = Mock(
text='{"key": "value"}',
json=Mock(side_effect=ValueError("Not JSON")),
raise_for_status=Mock()
)
loader = JSONLoader()
result = loader.load(SourceContent("https://example.com/data.json"))
assert all(part in result.content for part in ["key", "value"])
@patch('requests.get')
def test_url_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(
text='{"data": "test"}',
json=Mock(return_value={"data": "test"}),
raise_for_status=Mock()
)
headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
loader = JSONLoader()
loader.load(SourceContent("https://api.example.com/data.json"), headers=headers)
assert mock_get.call_args[1]['headers'] == headers
@patch('requests.get')
def test_url_network_failure(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = JSONLoader()
with pytest.raises(ValueError, match="Error fetching JSON from URL"):
loader.load(SourceContent("https://api.example.com/data.json"))
@patch('requests.get')
def test_url_http_error(self, mock_get):
mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404")))
loader = JSONLoader()
with pytest.raises(ValueError, match="Error fetching JSON from URL"):
loader.load(SourceContent("https://api.example.com/404.json"))

View File

@@ -0,0 +1,176 @@
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestMDXLoader:
def _write_temp_mdx(self, content):
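        """Helper to write MDX content to a temporary .mdx file and return its path."""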
f = tempfile.NamedTemporaryFile(mode='w', suffix='.mdx', delete=False)
f.write(content)
f.close()
return f.name
def _load_from_file(self, content):
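        """Helper that loads MDX content from a temp file; returns (result, path) and removes the file afterwards."""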
path = self._write_temp_mdx(content)
try:
loader = MDXLoader()
return loader.load(SourceContent(path)), path
finally:
os.unlink(path)
def test_load_basic_mdx_file(self):
content = """
import Component from './Component'
export const meta = { title: 'Test' }
# Test MDX File
This is a **markdown** file with JSX.
<Component prop="value" />
Some more content.
<div className="container">
<p>Nested content</p>
</div>
"""
result, path = self._load_from_file(content)
assert isinstance(result, LoaderResult)
assert all(tag not in result.content for tag in ["import", "export", "<Component", "<div", "</div>"])
assert all(text in result.content for text in ["# Test MDX File", "markdown", "Some more content", "Nested content"])
assert result.metadata["format"] == "mdx"
assert result.source == path
def test_mdx_multiple_imports_exports(self):
content = """
import React from 'react'
import { useState } from 'react'
import CustomComponent from './custom'
export default function Layout() { return null }
export const config = { test: true }
# Content
Regular markdown content here.
"""
result, _ = self._load_from_file(content)
assert "# Content" in result.content
assert "Regular markdown content here." in result.content
assert "import" not in result.content and "export" not in result.content
def test_complex_jsx_cleanup(self):
content = """
# MDX with Complex JSX
<div className="alert alert-info">
<strong>Info:</strong> This is important information.
<ul><li>Item 1</li><li>Item 2</li></ul>
</div>
Regular paragraph text.
<MyComponent prop1="value1">Nested content inside component</MyComponent>
"""
result, _ = self._load_from_file(content)
assert all(tag not in result.content for tag in ["<div", "<strong>", "<ul>", "<MyComponent"])
assert all(text in result.content for text in ["Info:", "Item 1", "Regular paragraph text.", "Nested content inside component"])
def test_whitespace_cleanup(self):
content = """
# Title
Some content.
More content after multiple newlines.
Final content.
"""
result, _ = self._load_from_file(content)
assert result.content.count('\n\n\n') == 0
assert result.content.startswith('# Title')
assert result.content.endswith('Final content.')
def test_only_jsx_content(self):
content = """
<div>
<h1>Only JSX content</h1>
<p>No markdown here</p>
</div>
"""
result, _ = self._load_from_file(content)
assert all(tag not in result.content for tag in ["<div>", "<h1>", "<p>"])
assert "Only JSX content" in result.content
assert "No markdown here" in result.content
@patch('requests.get')
def test_load_mdx_from_url(self, mock_get):
mock_get.return_value = Mock(text="# MDX from URL\n\nContent here.\n\n<Component />", raise_for_status=lambda: None)
loader = MDXLoader()
result = loader.load(SourceContent("https://example.com/content.mdx"))
assert "# MDX from URL" in result.content
assert "<Component />" not in result.content
@patch('requests.get')
def test_load_mdx_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(text="# Custom headers test", raise_for_status=lambda: None)
loader = MDXLoader()
loader.load(SourceContent("https://example.com"), headers={"Authorization": "Bearer token"})
assert mock_get.call_args[1]['headers'] == {"Authorization": "Bearer token"}
@patch('requests.get')
def test_mdx_url_fetch_error(self, mock_get):
mock_get.side_effect = Exception("Network error")
with pytest.raises(ValueError, match="Error fetching MDX from URL"):
MDXLoader().load(SourceContent("https://example.com"))
def test_load_inline_mdx_text(self):
content = """# Inline MDX\n\nimport Something from 'somewhere'\n\nContent with <Component prop=\"value\" />.\n\nexport const meta = { title: 'Test' }"""
loader = MDXLoader()
result = loader.load(SourceContent(content))
assert "# Inline MDX" in result.content
assert "Content with ." in result.content
def test_empty_result_after_cleaning(self):
content = """
import Something from 'somewhere'
export const config = {}
<div></div>
"""
result, _ = self._load_from_file(content)
assert result.content.strip() == ""
def test_edge_case_parsing(self):
content = """
# Title
<Component>
Multi-line
JSX content
</Component>
import { a, b } from 'module'
export { x, y }
Final text.
"""
result, _ = self._load_from_file(content)
assert "# Title" in result.content
assert "JSX content" in result.content
assert "Final text." in result.content
assert all(phrase not in result.content for phrase in ["import {", "export {", "<Component>"])

View File

@@ -0,0 +1,160 @@
import hashlib
import os
import tempfile
import pytest
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
def write_temp_file(content, suffix=".txt", encoding="utf-8"):
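    """Write content to a named temporary file and return its path."""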
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding=encoding) as f:
f.write(content)
return f.name
def cleanup_temp_file(path):
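    """Remove a temporary file, ignoring the case where it is already gone."""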
try:
os.unlink(path)
except FileNotFoundError:
pass
class TestTextFileLoader:
def test_basic_text_file(self):
content = "This is test content\nWith multiple lines\nAnd more text"
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert isinstance(result, LoaderResult)
assert result.content == content
assert result.source == path
assert result.doc_id
assert result.metadata in (None, {})
finally:
cleanup_temp_file(path)
def test_empty_file(self):
path = write_temp_file("")
try:
result = TextFileLoader().load(SourceContent(path))
assert result.content == ""
finally:
cleanup_temp_file(path)
def test_unicode_content(self):
content = "Hello 世界 🌍 émojis 🎉 åäö"
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert content in result.content
finally:
cleanup_temp_file(path)
def test_large_file(self):
content = "\n".join(f"Line {i}" for i in range(100))
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert "Line 0" in result.content
assert "Line 99" in result.content
assert result.content.count("\n") == 99
finally:
cleanup_temp_file(path)
def test_missing_file(self):
with pytest.raises(FileNotFoundError):
TextFileLoader().load(SourceContent("/nonexistent/path.txt"))
def test_permission_denied(self):
path = write_temp_file("Some content")
os.chmod(path, 0o000)
try:
with pytest.raises(PermissionError):
TextFileLoader().load(SourceContent(path))
finally:
os.chmod(path, 0o644)
cleanup_temp_file(path)
def test_doc_id_consistency(self):
content = "Consistent content"
path = write_temp_file(content)
try:
loader = TextFileLoader()
result1 = loader.load(SourceContent(path))
result2 = loader.load(SourceContent(path))
expected_id = hashlib.sha256((path + content).encode("utf-8")).hexdigest()
assert result1.doc_id == result2.doc_id == expected_id
finally:
cleanup_temp_file(path)
def test_various_extensions(self):
content = "Same content"
for ext in [".txt", ".md", ".log", ".json"]:
path = write_temp_file(content, suffix=ext)
try:
result = TextFileLoader().load(SourceContent(path))
assert result.content == content
finally:
cleanup_temp_file(path)
class TestTextLoader:
def test_basic_text(self):
content = "Raw text"
result = TextLoader().load(SourceContent(content))
expected_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
assert result.content == content
assert result.source == expected_hash
assert result.doc_id == expected_hash
def test_multiline_text(self):
content = "Line 1\nLine 2\nLine 3"
result = TextLoader().load(SourceContent(content))
assert "Line 2" in result.content
def test_empty_text(self):
result = TextLoader().load(SourceContent(""))
assert result.content == ""
assert result.source == hashlib.sha256("".encode("utf-8")).hexdigest()
def test_unicode_text(self):
content = "世界 🌍 émojis 🎉 åäö"
result = TextLoader().load(SourceContent(content))
assert content in result.content
def test_special_characters(self):
content = "!@#$$%^&*()_+-=~`{}[]\\|;:'\",.<>/?"
result = TextLoader().load(SourceContent(content))
assert result.content == content
def test_doc_id_uniqueness(self):
result1 = TextLoader().load(SourceContent("A"))
result2 = TextLoader().load(SourceContent("B"))
assert result1.doc_id != result2.doc_id
def test_whitespace_text(self):
content = " \n\t "
result = TextLoader().load(SourceContent(content))
assert result.content == content
def test_long_text(self):
content = "A" * 10000
result = TextLoader().load(SourceContent(content))
assert len(result.content) == 10000
class TestTextLoadersIntegration:
def test_consistency_between_loaders(self):
content = "Consistent content"
text_result = TextLoader().load(SourceContent(content))
file_path = write_temp_file(content)
try:
file_result = TextFileLoader().load(SourceContent(file_path))
assert text_result.content == file_result.content
assert text_result.source != file_result.source
assert text_result.doc_id != file_result.doc_id
finally:
cleanup_temp_file(file_path)

View File

@@ -0,0 +1,137 @@
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestWebPageLoader:
def setup_mock_response(self, text, status_code=200, content_type="text/html"):
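        """Build a mock requests response with the given body, status code, and content-type header."""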
response = Mock()
response.text = text
response.apparent_encoding = "utf-8"
response.status_code = status_code
response.headers = {"content-type": content_type}
return response
def setup_mock_soup(self, text, title=None, script_style_elements=None):
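        # Mock BeautifulSoup object: get_text() yields the given text, .title mirrors the
        # title argument, and calling the soup returns the script/style elements (if any)
        # so the loader can decompose them.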
soup = Mock()
soup.get_text.return_value = text
soup.title = Mock(string=title) if title is not None else None
soup.return_value = script_style_elements or []
return soup
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_basic_webpage(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><head><title>Test Page</title></head><body><p>Test content</p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert isinstance(result, LoaderResult)
assert result.content == "Test content"
assert result.metadata["title"] == "Test Page"
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
html = """
<html><head><title>Page with Scripts</title><style>body { color: red; }</style></head>
<body><script>console.log('test');</script><p>Visible content</p></body></html>
"""
mock_get.return_value = self.setup_mock_response(html)
scripts = [Mock(), Mock()]
styles = [Mock()]
for el in scripts + styles:
el.decompose = Mock()
mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/with-scripts"))
assert "Visible content" in result.content
for el in scripts + styles:
el.decompose.assert_called_once()
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body><p> Messy text </p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/messy-text"))
assert result.content is not None
assert result.metadata["title"] == ""
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_empty_or_missing_title(self, mock_bs, mock_get):
for title in [None, ""]:
mock_get.return_value = self.setup_mock_response("<html><head><title></title></head><body>Content</body></html>")
mock_bs.return_value = self.setup_mock_soup("Content", title=title)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert result.metadata["title"] == ""
@patch('requests.get')
def test_custom_and_default_headers(self, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Test</body></html>")
custom_headers = {"User-Agent": "Bot", "Authorization": "Bearer xyz", "Accept": "text/html"}
with patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') as mock_bs:
mock_bs.return_value = self.setup_mock_soup("Test")
WebPageLoader().load(SourceContent("https://example.com"), headers=custom_headers)
assert mock_get.call_args[1]['headers'] == custom_headers
@patch('requests.get')
def test_error_handling(self, mock_get):
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
mock_get.side_effect = error
with pytest.raises(ValueError, match="Error loading webpage"):
WebPageLoader().load(SourceContent("https://example.com"))
@patch('requests.get')
def test_timeout_and_http_error(self, mock_get):
import requests
mock_get.side_effect = requests.Timeout("Timeout")
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com"))
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404")
mock_get.side_effect = None
mock_get.return_value = mock_response
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com/404"))
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_doc_id_consistency(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Doc</body></html>")
mock_bs.return_value = self.setup_mock_soup("Doc")
loader = WebPageLoader()
result1 = loader.load(SourceContent("https://example.com"))
result2 = loader.load(SourceContent("https://example.com"))
assert result1.doc_id == result2.doc_id
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_status_code_and_content_type(self, mock_bs, mock_get):
for status in [200, 201, 301]:
mock_get.return_value = self.setup_mock_response(f"<html><body>Status {status}</body></html>", status_code=status)
mock_bs.return_value = self.setup_mock_soup(f"Status {status}")
result = WebPageLoader().load(SourceContent(f"https://example.com/{status}"))
assert result.metadata["status_code"] == status
for ctype in ["text/html", "text/plain", "application/xhtml+xml"]:
mock_get.return_value = self.setup_mock_response("<html><body>Content</body></html>", content_type=ctype)
mock_bs.return_value = self.setup_mock_soup("Content")
result = WebPageLoader().load(SourceContent("https://example.com"))
assert result.metadata["content_type"] == ctype
