diff --git a/src/crewai_tools/adapters/rag_adapter.py b/src/crewai_tools/adapters/rag_adapter.py new file mode 100644 index 000000000..78011328c --- /dev/null +++ b/src/crewai_tools/adapters/rag_adapter.py @@ -0,0 +1,41 @@ +from typing import Any, Optional + +from crewai_tools.rag.core import RAG +from crewai_tools.tools.rag.rag_tool import Adapter + + +class RAGAdapter(Adapter): + def __init__( + self, + collection_name: str = "crewai_knowledge_base", + persist_directory: Optional[str] = None, + embedding_model: str = "text-embedding-3-small", + top_k: int = 5, + embedding_api_key: Optional[str] = None, + **embedding_kwargs + ): + super().__init__() + + # Prepare embedding configuration + embedding_config = { + "api_key": embedding_api_key, + **embedding_kwargs + } + + self._adapter = RAG( + collection_name=collection_name, + persist_directory=persist_directory, + embedding_model=embedding_model, + top_k=top_k, + embedding_config=embedding_config + ) + + def query(self, question: str) -> str: + return self._adapter.query(question) + + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self._adapter.add(*args, **kwargs) diff --git a/src/crewai_tools/rag/__init__.py b/src/crewai_tools/rag/__init__.py new file mode 100644 index 000000000..8d08b2907 --- /dev/null +++ b/src/crewai_tools/rag/__init__.py @@ -0,0 +1,8 @@ +from crewai_tools.rag.core import RAG, EmbeddingService +from crewai_tools.rag.data_types import DataType + +__all__ = [ + "RAG", + "EmbeddingService", + "DataType", +] diff --git a/src/crewai_tools/rag/base_loader.py b/src/crewai_tools/rag/base_loader.py new file mode 100644 index 000000000..e38d6f8c1 --- /dev/null +++ b/src/crewai_tools/rag/base_loader.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field + +from crewai_tools.rag.misc import compute_sha256 +from crewai_tools.rag.source_content import SourceContent + + +class LoaderResult(BaseModel): + content: str = Field(description="The text content of the source") + source: str = Field(description="The source of the content", default="unknown") + metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict) + doc_id: str = Field(description="The id of the document") + + +class BaseLoader(ABC): + def __init__(self, config: Optional[Dict[str, Any]] = None): + self.config = config or {} + + @abstractmethod + def load(self, content: SourceContent, **kwargs) -> LoaderResult: + ... + + def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str: + """ + Generate a unique document id from the source reference and content. + Missing values default to empty strings, so the id is the SHA-256 hash of + whichever parts are provided, concatenated. + + Both are optional because the TEXT content type does not have a source reference; in that case the content alone identifies the document.
+ """ + + source_ref = source_ref or "" + content = content or "" + + return compute_sha256(source_ref + content) diff --git a/src/crewai_tools/rag/chunkers/__init__.py b/src/crewai_tools/rag/chunkers/__init__.py new file mode 100644 index 000000000..f48483391 --- /dev/null +++ b/src/crewai_tools/rag/chunkers/__init__.py @@ -0,0 +1,15 @@ +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from crewai_tools.rag.chunkers.default_chunker import DefaultChunker +from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker +from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker + +__all__ = [ + "BaseChunker", + "DefaultChunker", + "TextChunker", + "DocxChunker", + "MdxChunker", + "CsvChunker", + "JsonChunker", + "XmlChunker", +] diff --git a/src/crewai_tools/rag/chunkers/base_chunker.py b/src/crewai_tools/rag/chunkers/base_chunker.py new file mode 100644 index 000000000..deafbfc7a --- /dev/null +++ b/src/crewai_tools/rag/chunkers/base_chunker.py @@ -0,0 +1,167 @@ +from typing import List, Optional +import re + +class RecursiveCharacterTextSplitter: + """ + A text splitter that recursively splits text based on a hierarchy of separators. + """ + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + ): + """ + Initialize the RecursiveCharacterTextSplitter. + + Args: + chunk_size: Maximum size of each chunk + chunk_overlap: Number of characters to overlap between chunks + separators: List of separators to use for splitting (in order of preference) + keep_separator: Whether to keep the separator in the split text + """ + if chunk_overlap >= chunk_size: + raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})") + + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._keep_separator = keep_separator + + self._separators = separators or [ + "\n\n", + "\n", + " ", + "", + ] + + def split_text(self, text: str) -> List[str]: + return self._split_text(text, self._separators) + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + separator = separators[-1] + new_separators = [] + + for i, sep in enumerate(separators): + if sep == "": + separator = sep + break + if re.search(re.escape(sep), text): + separator = sep + new_separators = separators[i + 1:] + break + + splits = self._split_text_with_separator(text, separator) + + good_splits = [] + + for split in splits: + if len(split) < self._chunk_size: + good_splits.append(split) + else: + if new_separators: + other_info = self._split_text(split, new_separators) + good_splits.extend(other_info) + else: + good_splits.extend(self._split_by_characters(split)) + + return self._merge_splits(good_splits, separator) + + def _split_text_with_separator(self, text: str, separator: str) -> List[str]: + if separator == "": + return list(text) + + if self._keep_separator and separator in text: + parts = text.split(separator) + splits = [] + + for i, part in enumerate(parts): + if i == 0: + splits.append(part) + elif i == len(parts) - 1: + if part: + splits.append(separator + part) + else: + if part: + splits.append(separator + part) + else: + if splits: + splits[-1] += separator + + return [s for s in splits if s] + else: + return text.split(separator) + + def _split_by_characters(self, text: str) -> List[str]: + chunks = [] + for i in range(0, len(text), self._chunk_size): + chunks.append(text[i:i + self._chunk_size]) + 
return chunks + + def _merge_splits(self, splits: List[str], separator: str) -> List[str]: + """Merge splits into chunks with proper overlap.""" + docs = [] + current_doc = [] + total = 0 + + for split in splits: + split_len = len(split) + + if total + split_len > self._chunk_size and current_doc: + if separator == "": + doc = "".join(current_doc) + else: + doc = separator.join(current_doc) + + if doc: + docs.append(doc) + + # Handle overlap by keeping some of the previous content + while total > self._chunk_overlap and len(current_doc) > 1: + removed = current_doc.pop(0) + total -= len(removed) + if separator != "": + total -= len(separator) + + current_doc.append(split) + total += split_len + if separator != "" and len(current_doc) > 1: + total += len(separator) + + if current_doc: + if separator == "": + doc = "".join(current_doc) + else: + doc = separator.join(current_doc) + + if doc: + docs.append(doc) + + return docs + +class BaseChunker: + def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True): + """ + Initialize the Chunker + + Args: + chunk_size: Maximum size of each chunk + chunk_overlap: Number of characters to overlap between chunks + separators: List of separators to use for splitting + keep_separator: Whether to keep separators in the chunks + """ + + self._splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separators=separators, + keep_separator=keep_separator, + ) + + + def chunk(self, text: str) -> List[str]: + if not text or not text.strip(): + return [] + + return self._splitter.split_text(text) diff --git a/src/crewai_tools/rag/chunkers/default_chunker.py b/src/crewai_tools/rag/chunkers/default_chunker.py new file mode 100644 index 000000000..0d0ec6935 --- /dev/null +++ b/src/crewai_tools/rag/chunkers/default_chunker.py @@ -0,0 +1,6 @@ +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from typing import List, Optional + +class DefaultChunker(BaseChunker): + def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True): + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) diff --git a/src/crewai_tools/rag/chunkers/structured_chunker.py b/src/crewai_tools/rag/chunkers/structured_chunker.py new file mode 100644 index 000000000..483f92588 --- /dev/null +++ b/src/crewai_tools/rag/chunkers/structured_chunker.py @@ -0,0 +1,49 @@ +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from typing import List, Optional + + +class CsvChunker(BaseChunker): + def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\nRow ", # Row boundaries (from CSVLoader format) + "\n", # Line breaks + " | ", # Column separators + ", ", # Comma separators + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) + + +class JsonChunker(BaseChunker): + def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n\n", # Object/array boundaries + "\n", # Line breaks + "},", # Object endings + "],", # Array endings + ", ", # Property separators + ": ", # Key-value separators + " ", # Word breaks + "", # Character level + ] + 
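+ # Separator order matters: the splitter walks this list top-down, so
+ # structural boundaries (object and array endings) are preferred over
+ # word- or character-level breaks when carving up serialized JSON.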
super().__init__(chunk_size, chunk_overlap, separators, keep_separator) + + +class XmlChunker(BaseChunker): + def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n\n", # Element boundaries + "\n", # Line breaks + ">", # Tag endings + ". ", # Sentence endings (for text content) + "! ", # Exclamation endings + "? ", # Question endings + ", ", # Comma separators + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) diff --git a/src/crewai_tools/rag/chunkers/text_chunker.py b/src/crewai_tools/rag/chunkers/text_chunker.py new file mode 100644 index 000000000..2e76df8ab --- /dev/null +++ b/src/crewai_tools/rag/chunkers/text_chunker.py @@ -0,0 +1,59 @@ +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from typing import List, Optional + + +class TextChunker(BaseChunker): + def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n\n\n", # Multiple line breaks (sections) + "\n\n", # Paragraph breaks + "\n", # Line breaks + ". ", # Sentence endings + "! ", # Exclamation endings + "? ", # Question endings + "; ", # Semicolon breaks + ", ", # Comma breaks + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) + + +class DocxChunker(BaseChunker): + def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n\n\n", # Multiple line breaks (major sections) + "\n\n", # Paragraph breaks + "\n", # Line breaks + ". ", # Sentence endings + "! ", # Exclamation endings + "? ", # Question endings + "; ", # Semicolon breaks + ", ", # Comma breaks + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) + + +class MdxChunker(BaseChunker): + def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n## ", # H2 headers (major sections) + "\n### ", # H3 headers (subsections) + "\n#### ", # H4 headers (sub-subsections) + "\n\n", # Paragraph breaks + "\n```", # Code block boundaries + "\n", # Line breaks + ". ", # Sentence endings + "! ", # Exclamation endings + "? ", # Question endings + "; ", # Semicolon breaks + ", ", # Comma breaks + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) diff --git a/src/crewai_tools/rag/chunkers/web_chunker.py b/src/crewai_tools/rag/chunkers/web_chunker.py new file mode 100644 index 000000000..2712a6c69 --- /dev/null +++ b/src/crewai_tools/rag/chunkers/web_chunker.py @@ -0,0 +1,20 @@ +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from typing import List, Optional + + +class WebsiteChunker(BaseChunker): + def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True): + if separators is None: + separators = [ + "\n\n\n", # Major section breaks + "\n\n", # Paragraph breaks + "\n", # Line breaks + ". ", # Sentence endings + "! ", # Exclamation endings + "? 
", # Question endings + "; ", # Semicolon breaks + ", ", # Comma breaks + " ", # Word breaks + "", # Character level + ] + super().__init__(chunk_size, chunk_overlap, separators, keep_separator) diff --git a/src/crewai_tools/rag/core.py b/src/crewai_tools/rag/core.py new file mode 100644 index 000000000..0aa4b666c --- /dev/null +++ b/src/crewai_tools/rag/core.py @@ -0,0 +1,232 @@ +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 + +import chromadb +import litellm +from pydantic import BaseModel, Field, PrivateAttr + +from crewai_tools.tools.rag.rag_tool import Adapter +from crewai_tools.rag.data_types import DataType +from crewai_tools.rag.base_loader import BaseLoader +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from crewai_tools.rag.source_content import SourceContent +from crewai_tools.rag.misc import compute_sha256 + +logger = logging.getLogger(__name__) + + +class EmbeddingService: + def __init__(self, model: str = "text-embedding-3-small", **kwargs): + self.model = model + self.kwargs = kwargs + + def embed_text(self, text: str) -> List[float]: + try: + response = litellm.embedding( + model=self.model, + input=[text], + **self.kwargs + ) + return response.data[0]['embedding'] + except Exception as e: + logger.error(f"Error generating embedding: {e}") + raise + + def embed_batch(self, texts: List[str]) -> List[List[float]]: + if not texts: + return [] + + try: + response = litellm.embedding( + model=self.model, + input=texts, + **self.kwargs + ) + return [data['embedding'] for data in response.data] + except Exception as e: + logger.error(f"Error generating batch embeddings: {e}") + raise + + +class Document(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + content: str + metadata: Dict[str, Any] = Field(default_factory=dict) + data_type: DataType = DataType.TEXT + source: Optional[str] = None + + +class RAG(Adapter): + collection_name: str = "crewai_knowledge_base" + persist_directory: Optional[str] = None + embedding_model: str = "text-embedding-3-large" + summarize: bool = False + top_k: int = 5 + embedding_config: Dict[str, Any] = Field(default_factory=dict) + + _client: Any = PrivateAttr() + _collection: Any = PrivateAttr() + _embedding_service: EmbeddingService = PrivateAttr() + + def model_post_init(self, __context: Any) -> None: + try: + if self.persist_directory: + self._client = chromadb.PersistentClient(path=self.persist_directory) + else: + self._client = chromadb.Client() + + self._collection = self._client.get_or_create_collection( + name=self.collection_name, + metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"} + ) + + self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config) + except Exception as e: + logger.error(f"Failed to initialize ChromaDB: {e}") + raise + + super().model_post_init(__context) + + def add( + self, + content: str | Path, + data_type: Optional[Union[str, DataType]] = None, + metadata: Optional[Dict[str, Any]] = None, + loader: Optional[BaseLoader] = None, + chunker: Optional[BaseChunker] = None, + **kwargs: Any + ) -> None: + source_content = SourceContent(content) + + data_type = self._get_data_type(data_type=data_type, content=source_content) + + if not loader: + loader = data_type.get_loader() + + if not chunker: + chunker = data_type.get_chunker() + + loader_result = loader.load(source_content) + doc_id = loader_result.doc_id + + existing_doc = 
self._collection.get(where={"source": source_content.source_ref}, limit=1) + existing_doc_id = existing_doc['metadatas'][0]['doc_id'] if existing_doc and existing_doc['metadatas'] else None + + if existing_doc_id == doc_id: + logger.warning(f"Document with source {loader_result.source} already exists") + return + + # A document with the same source ref exists but its content has changed, so delete the old reference + if existing_doc_id and existing_doc_id != loader_result.doc_id: + logger.warning(f"Deleting old document with doc_id {existing_doc_id}") + self._collection.delete(where={"doc_id": existing_doc_id}) + + documents = [] + + chunks = chunker.chunk(loader_result.content) + for i, chunk in enumerate(chunks): + doc_metadata = (metadata or {}).copy() + doc_metadata['chunk_index'] = i + documents.append(Document( + id=compute_sha256(chunk), + content=chunk, + metadata=doc_metadata, + data_type=data_type, + source=loader_result.source + )) + + if not documents: + logger.warning("No documents to add") + return + + contents = [doc.content for doc in documents] + try: + embeddings = self._embedding_service.embed_batch(contents) + except Exception as e: + logger.error(f"Failed to generate embeddings: {e}") + return + + ids = [doc.id for doc in documents] + metadatas = [] + + for doc in documents: + doc_metadata = doc.metadata.copy() + doc_metadata.update({ + "data_type": doc.data_type.value, + "source": doc.source, + "doc_id": doc_id + }) + metadatas.append(doc_metadata) + + try: + self._collection.add( + ids=ids, + embeddings=embeddings, + documents=contents, + metadatas=metadatas, + ) + logger.info(f"Added {len(documents)} documents to knowledge base") + except Exception as e: + logger.error(f"Failed to add documents to ChromaDB: {e}") + + def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str: + try: + question_embedding = self._embedding_service.embed_text(question) + + results = self._collection.query( + query_embeddings=[question_embedding], + n_results=self.top_k, + where=where, + include=["documents", "metadatas", "distances"] + ) + + if not results or not results.get("documents") or not results["documents"][0]: + return "No relevant content found."
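+ # Chroma returns one list of matches per query embedding; a single
+ # embedding is passed above, so index [0] below selects its result set.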
+ + documents = results["documents"][0] + metadatas = results.get("metadatas", [None])[0] or [] + distances = results.get("distances", [None])[0] or [] + + # Return sources with relevance scores + formatted_results = [] + for i, doc in enumerate(documents): + metadata = metadatas[i] if i < len(metadatas) else {} + distance = distances[i] if i < len(distances) else 1.0 + source = metadata.get("source", "unknown") if metadata else "unknown" + score = 1 - distance if distance is not None else 0 # Convert distance to similarity + formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}") + + return "\n\n".join(formatted_results) + except Exception as e: + logger.error(f"Query failed: {e}") + return f"Error querying knowledge base: {e}" + + def delete_collection(self) -> None: + try: + self._client.delete_collection(self.collection_name) + logger.info(f"Deleted collection: {self.collection_name}") + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + + def get_collection_info(self) -> Dict[str, Any]: + try: + count = self._collection.count() + return { + "name": self.collection_name, + "count": count, + "embedding_model": self.embedding_model + } + except Exception as e: + logger.error(f"Failed to get collection info: {e}") + return {"error": str(e)} + + def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType: + try: + if isinstance(data_type, str): + return DataType(data_type) + except Exception as e: + pass + + return content.data_type diff --git a/src/crewai_tools/rag/data_types.py b/src/crewai_tools/rag/data_types.py new file mode 100644 index 000000000..d2d265cce --- /dev/null +++ b/src/crewai_tools/rag/data_types.py @@ -0,0 +1,137 @@ +from enum import Enum +from pathlib import Path +from urllib.parse import urlparse +import os +from crewai_tools.rag.chunkers.base_chunker import BaseChunker +from crewai_tools.rag.base_loader import BaseLoader + +class DataType(str, Enum): + PDF_FILE = "pdf_file" + TEXT_FILE = "text_file" + CSV = "csv" + JSON = "json" + XML = "xml" + DOCX = "docx" + MDX = "mdx" + + # Database types + MYSQL = "mysql" + POSTGRES = "postgres" + + # Repository types + GITHUB = "github" + DIRECTORY = "directory" + + # Web types + WEBSITE = "website" + DOCS_SITE = "docs_site" + + # Raw types + TEXT = "text" + + + def get_chunker(self) -> BaseChunker: + from importlib import import_module + + chunkers = { + DataType.TEXT_FILE: ("text_chunker", "TextChunker"), + DataType.TEXT: ("text_chunker", "TextChunker"), + DataType.DOCX: ("text_chunker", "DocxChunker"), + DataType.MDX: ("text_chunker", "MdxChunker"), + + # Structured formats + DataType.CSV: ("structured_chunker", "CsvChunker"), + DataType.JSON: ("structured_chunker", "JsonChunker"), + DataType.XML: ("structured_chunker", "XmlChunker"), + + DataType.WEBSITE: ("web_chunker", "WebsiteChunker"), + } + + module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker")) + module_path = f"crewai_tools.rag.chunkers.{module_name}" + + try: + module = import_module(module_path) + return getattr(module, class_name)() + except Exception as e: + raise ValueError(f"Error loading chunker for {self}: {e}") + + def get_loader(self) -> BaseLoader: + from importlib import import_module + + loaders = { + DataType.TEXT_FILE: ("text_loader", "TextFileLoader"), + DataType.TEXT: ("text_loader", "TextLoader"), + DataType.XML: ("xml_loader", "XMLLoader"), + DataType.WEBSITE: ("webpage_loader", "WebPageLoader"), + DataType.MDX: ("mdx_loader", 
"MDXLoader"), + DataType.JSON: ("json_loader", "JSONLoader"), + DataType.DOCX: ("docx_loader", "DOCXLoader"), + DataType.CSV: ("csv_loader", "CSVLoader"), + DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"), + } + + module_name, class_name = loaders.get(self, ("text_loader", "TextLoader")) + module_path = f"crewai_tools.rag.loaders.{module_name}" + try: + module = import_module(module_path) + return getattr(module, class_name)() + except Exception as e: + raise ValueError(f"Error loading loader for {self}: {e}") + +class DataTypes: + @staticmethod + def from_content(content: str | Path | None = None) -> DataType: + if content is None: + return DataType.TEXT + + if isinstance(content, Path): + content = str(content) + + is_url = False + if isinstance(content, str): + try: + url = urlparse(content) + is_url = (url.scheme and url.netloc) or url.scheme == "file" + except Exception: + pass + + def get_file_type(path: str) -> DataType | None: + mapping = { + ".pdf": DataType.PDF_FILE, + ".csv": DataType.CSV, + ".mdx": DataType.MDX, + ".md": DataType.MDX, + ".docx": DataType.DOCX, + ".json": DataType.JSON, + ".xml": DataType.XML, + ".txt": DataType.TEXT_FILE, + } + for ext, dtype in mapping.items(): + if path.endswith(ext): + return dtype + return None + + if is_url: + dtype = get_file_type(url.path) + if dtype: + return dtype + + if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"): + return DataType.DOCS_SITE + if "github.com" in url.netloc: + return DataType.GITHUB + + return DataType.WEBSITE + + if os.path.isfile(content): + dtype = get_file_type(content) + if dtype: + return dtype + + if os.path.exists(content): + return DataType.TEXT_FILE + elif os.path.isdir(content): + return DataType.DIRECTORY + + return DataType.TEXT diff --git a/src/crewai_tools/rag/loaders/__init__.py b/src/crewai_tools/rag/loaders/__init__.py new file mode 100644 index 000000000..503651468 --- /dev/null +++ b/src/crewai_tools/rag/loaders/__init__.py @@ -0,0 +1,20 @@ +from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader +from crewai_tools.rag.loaders.xml_loader import XMLLoader +from crewai_tools.rag.loaders.webpage_loader import WebPageLoader +from crewai_tools.rag.loaders.mdx_loader import MDXLoader +from crewai_tools.rag.loaders.json_loader import JSONLoader +from crewai_tools.rag.loaders.docx_loader import DOCXLoader +from crewai_tools.rag.loaders.csv_loader import CSVLoader +from crewai_tools.rag.loaders.directory_loader import DirectoryLoader + +__all__ = [ + "TextFileLoader", + "TextLoader", + "XMLLoader", + "WebPageLoader", + "MDXLoader", + "JSONLoader", + "DOCXLoader", + "CSVLoader", + "DirectoryLoader", +] diff --git a/src/crewai_tools/rag/loaders/csv_loader.py b/src/crewai_tools/rag/loaders/csv_loader.py new file mode 100644 index 000000000..e389123a7 --- /dev/null +++ b/src/crewai_tools/rag/loaders/csv_loader.py @@ -0,0 +1,72 @@ +import csv +from io import StringIO + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class CSVLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + source_ref = source_content.source_ref + + content_str = source_content.source + if source_content.is_url(): + content_str = self._load_from_url(content_str, kwargs) + elif source_content.path_exists(): + content_str = self._load_from_file(content_str) + + return self._parse_csv(content_str, source_ref) + + + def _load_from_url(self, url: str, kwargs: dict) 
-> str: + import requests + + headers = kwargs.get("headers", { + "Accept": "text/csv, application/csv, text/plain", + "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)" + }) + + try: + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + return response.text + except Exception as e: + raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}") + + def _load_from_file(self, path: str) -> str: + with open(path, "r", encoding="utf-8") as file: + return file.read() + + def _parse_csv(self, content: str, source_ref: str) -> LoaderResult: + try: + csv_reader = csv.DictReader(StringIO(content)) + + text_parts = [] + headers = csv_reader.fieldnames + + if headers: + text_parts.append("Headers: " + " | ".join(headers)) + text_parts.append("-" * 50) + + for row_num, row in enumerate(csv_reader, 1): + row_text = " | ".join([f"{k}: {v}" for k, v in row.items() if v]) + text_parts.append(f"Row {row_num}: {row_text}") + + text = "\n".join(text_parts) + + metadata = { + "format": "csv", + "columns": headers, + "rows": len(text_parts) - 2 if headers else 0 + } + + except Exception as e: + text = content + metadata = {"format": "csv", "parse_error": str(e)} + + return LoaderResult( + content=text, + source=source_ref, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=source_ref, content=text) + ) diff --git a/src/crewai_tools/rag/loaders/directory_loader.py b/src/crewai_tools/rag/loaders/directory_loader.py new file mode 100644 index 000000000..7bc5f298b --- /dev/null +++ b/src/crewai_tools/rag/loaders/directory_loader.py @@ -0,0 +1,142 @@ +import os +from pathlib import Path +from typing import List + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class DirectoryLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + """ + Load and process all files from a directory recursively. + + Args: + source: Directory path or URL to a directory listing + **kwargs: Additional options: + - recursive: bool (default True) - Whether to search recursively + - include_extensions: list - Only include files with these extensions + - exclude_extensions: list - Exclude files with these extensions + - max_files: int - Maximum number of files to process + """ + source_ref = source_content.source_ref + + if source_content.is_url(): + raise ValueError("URL directory loading is not supported. 
Please provide a local directory path.") + + if not os.path.exists(source_ref): + raise FileNotFoundError(f"Directory does not exist: {source_ref}") + + if not os.path.isdir(source_ref): + raise ValueError(f"Path is not a directory: {source_ref}") + + return self._process_directory(source_ref, kwargs) + + def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult: + recursive = kwargs.get("recursive", True) + include_extensions = kwargs.get("include_extensions", None) + exclude_extensions = kwargs.get("exclude_extensions", None) + max_files = kwargs.get("max_files", None) + + files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions) + + if max_files and len(files) > max_files: + files = files[:max_files] + + all_contents = [] + processed_files = [] + errors = [] + + for file_path in files: + try: + result = self._process_single_file(file_path) + if result: + all_contents.append(f"=== File: {file_path} ===\n{result.content}") + processed_files.append({ + "path": file_path, + "metadata": result.metadata, + "source": result.source + }) + except Exception as e: + error_msg = f"Error processing {file_path}: {str(e)}" + errors.append(error_msg) + all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}") + + combined_content = "\n\n".join(all_contents) + + metadata = { + "format": "directory", + "directory_path": dir_path, + "total_files": len(files), + "processed_files": len(processed_files), + "errors": len(errors), + "file_details": processed_files, + "error_details": errors + } + + return LoaderResult( + content=combined_content, + source=dir_path, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content) + ) + + def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]: + """Find all files in directory matching criteria.""" + files = [] + + if recursive: + for root, dirs, filenames in os.walk(dir_path): + dirs[:] = [d for d in dirs if not d.startswith('.')] + + for filename in filenames: + if self._should_include_file(filename, include_ext, exclude_ext): + files.append(os.path.join(root, filename)) + else: + try: + for item in os.listdir(dir_path): + item_path = os.path.join(dir_path, item) + if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext): + files.append(item_path) + except PermissionError: + pass + + return sorted(files) + + def _should_include_file(self, filename: str, include_ext: List[str] = None, exclude_ext: List[str] = None) -> bool: + """Determine if a file should be included based on criteria.""" + if filename.startswith('.'): + return False + + _, ext = os.path.splitext(filename.lower()) + + if include_ext: + if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]: + return False + + if exclude_ext: + if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]: + return False + + return True + + def _process_single_file(self, file_path: str) -> LoaderResult: + from crewai_tools.rag.data_types import DataTypes + + data_type = DataTypes.from_content(Path(file_path)) + + loader = data_type.get_loader() + + result = loader.load(SourceContent(file_path)) + + if result.metadata is None: + result.metadata = {} + + result.metadata.update({ + "file_path": file_path, + "file_size": os.path.getsize(file_path), + "data_type": str(data_type), + "loader_type": loader.__class__.__name__ + }) + + return result diff 
--git a/src/crewai_tools/rag/loaders/docx_loader.py b/src/crewai_tools/rag/loaders/docx_loader.py new file mode 100644 index 000000000..2f5df23af --- /dev/null +++ b/src/crewai_tools/rag/loaders/docx_loader.py @@ -0,0 +1,72 @@ +import os +import tempfile + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class DOCXLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + try: + from docx import Document as DocxDocument + except ImportError: + raise ImportError("python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]") + + source_ref = source_content.source_ref + + if source_content.is_url(): + temp_file = self._download_from_url(source_ref, kwargs) + try: + return self._load_from_file(temp_file, source_ref, DocxDocument) + finally: + os.unlink(temp_file) + elif source_content.path_exists(): + return self._load_from_file(source_ref, source_ref, DocxDocument) + else: + raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}") + + def _download_from_url(self, url: str, kwargs: dict) -> str: + import requests + + headers = kwargs.get("headers", { + "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)" + }) + + try: + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + + # Create temporary file to save the DOCX content + with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file: + temp_file.write(response.content) + return temp_file.name + except Exception as e: + raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}") + + def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult: + try: + doc = DocxDocument(file_path) + + text_parts = [] + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + text_parts.append(paragraph.text) + + content = "\n".join(text_parts) + + metadata = { + "format": "docx", + "paragraphs": len(doc.paragraphs), + "tables": len(doc.tables) + } + + return LoaderResult( + content=content, + source=source_ref, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=source_ref, content=content) + ) + + except Exception as e: + raise ValueError(f"Error loading DOCX file: {str(e)}") diff --git a/src/crewai_tools/rag/loaders/json_loader.py b/src/crewai_tools/rag/loaders/json_loader.py new file mode 100644 index 000000000..6efab393a --- /dev/null +++ b/src/crewai_tools/rag/loaders/json_loader.py @@ -0,0 +1,69 @@ +import json + +from crewai_tools.rag.source_content import SourceContent +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult + + +class JSONLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + source_ref = source_content.source_ref + content = source_content.source + + if source_content.is_url(): + content = self._load_from_url(source_ref, kwargs) + elif source_content.path_exists(): + content = self._load_from_file(source_ref) + + return self._parse_json(content, source_ref) + + def _load_from_url(self, url: str, kwargs: dict) -> str: + import requests + + headers = kwargs.get("headers", { + "Accept": "application/json", + "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)" + }) + + try: + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + 
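+ # Normalize the payload: if the body parses as JSON it is re-serialized
+ # pretty-printed, otherwise the raw response text is returned as-is.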
return response.text if not self._is_json_response(response) else json.dumps(response.json(), indent=2) + except Exception as e: + raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}") + + def _is_json_response(self, response) -> bool: + try: + response.json() + return True + except ValueError: + return False + + def _load_from_file(self, path: str) -> str: + with open(path, "r", encoding="utf-8") as file: + return file.read() + + def _parse_json(self, content: str, source_ref: str) -> LoaderResult: + try: + data = json.loads(content) + if isinstance(data, dict): + text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items()) + elif isinstance(data, list): + text = "\n".join(json.dumps(item, indent=0) for item in data) + else: + text = json.dumps(data, indent=0) + + metadata = { + "format": "json", + "type": type(data).__name__, + "size": len(data) if isinstance(data, (list, dict)) else 1 + } + except json.JSONDecodeError as e: + text = content + metadata = {"format": "json", "parse_error": str(e)} + + return LoaderResult( + content=text, + source=source_ref, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=source_ref, content=text) + ) diff --git a/src/crewai_tools/rag/loaders/mdx_loader.py b/src/crewai_tools/rag/loaders/mdx_loader.py new file mode 100644 index 000000000..6da9dc896 --- /dev/null +++ b/src/crewai_tools/rag/loaders/mdx_loader.py @@ -0,0 +1,59 @@ +import re + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + +class MDXLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + source_ref = source_content.source_ref + content = source_content.source + + if source_content.is_url(): + content = self._load_from_url(source_ref, kwargs) + elif source_content.path_exists(): + content = self._load_from_file(source_ref) + + return self._parse_mdx(content, source_ref) + + def _load_from_url(self, url: str, kwargs: dict) -> str: + import requests + + headers = kwargs.get("headers", { + "Accept": "text/markdown, text/x-markdown, text/plain", + "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)" + }) + + try: + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + return response.text + except Exception as e: + raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}") + + def _load_from_file(self, path: str) -> str: + with open(path, "r", encoding="utf-8") as file: + return file.read() + + def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult: + cleaned_content = content + + # Remove import statements + cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE) + + # Remove export statements + cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE) + + # Remove JSX tags (simple approach) + cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content) + + # Clean up extra whitespace + cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content) + cleaned_content = cleaned_content.strip() + + metadata = {"format": "mdx"} + return LoaderResult( + content=cleaned_content, + source=source_ref, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content) + ) diff --git a/src/crewai_tools/rag/loaders/text_loader.py b/src/crewai_tools/rag/loaders/text_loader.py new file mode 100644 index 000000000..a97cf29f4 --- /dev/null +++ b/src/crewai_tools/rag/loaders/text_loader.py @@ 
-0,0 +1,28 @@ + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class TextFileLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + source_ref = source_content.source_ref + if not source_content.path_exists(): + raise FileNotFoundError(f"The following file does not exist: {source_content.source}") + + with open(source_content.source, "r", encoding="utf-8") as file: + content = file.read() + + return LoaderResult( + content=content, + source=source_ref, + doc_id=self.generate_doc_id(source_ref=source_ref, content=content) + ) + + +class TextLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + return LoaderResult( + content=source_content.source, + source=source_content.source_ref, + doc_id=self.generate_doc_id(content=source_content.source) + ) diff --git a/src/crewai_tools/rag/loaders/webpage_loader.py b/src/crewai_tools/rag/loaders/webpage_loader.py new file mode 100644 index 000000000..4fcb1e0c4 --- /dev/null +++ b/src/crewai_tools/rag/loaders/webpage_loader.py @@ -0,0 +1,47 @@ +import re +import requests +from bs4 import BeautifulSoup + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + +class WebPageLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + url = source_content.source + headers = kwargs.get("headers", { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "Accept-Language": "en-US,en;q=0.9", + }) + + try: + response = requests.get(url, timeout=15, headers=headers) + response.encoding = response.apparent_encoding + + soup = BeautifulSoup(response.text, "html.parser") + + for script in soup(["script", "style"]): + script.decompose() + + text = soup.get_text(" ") + text = re.sub("[ \t]+", " ", text) + text = re.sub("\\s+\n\\s+", "\n", text) + text = text.strip() + + title = soup.title.string.strip() if soup.title and soup.title.string else "" + metadata = { + "url": url, + "title": title, + "status_code": response.status_code, + "content_type": response.headers.get("content-type", "") + } + + return LoaderResult( + content=text, + source=url, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=url, content=text) + ) + + except Exception as e: + raise ValueError(f"Error loading webpage {url}: {str(e)}") diff --git a/src/crewai_tools/rag/loaders/xml_loader.py b/src/crewai_tools/rag/loaders/xml_loader.py new file mode 100644 index 000000000..ffafdb9d9 --- /dev/null +++ b/src/crewai_tools/rag/loaders/xml_loader.py @@ -0,0 +1,61 @@ +import os +import xml.etree.ElementTree as ET + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + +class XMLLoader(BaseLoader): + def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: + source_ref = source_content.source_ref + content = source_content.source + + if source_content.is_url(): + content = self._load_from_url(source_ref, kwargs) + elif os.path.exists(source_ref): + content = self._load_from_file(source_ref) + + return self._parse_xml(content, source_ref) + + def _load_from_url(self, url: str, kwargs: dict) -> str: + 
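+ # requests is imported lazily here, as in the other loaders, presumably
+ # so it remains an optional dependency until a URL source is fetched.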
import requests + + headers = kwargs.get("headers", { + "Accept": "application/xml, text/xml, text/plain", + "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)" + }) + + try: + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + return response.text + except Exception as e: + raise ValueError(f"Error fetching XML from URL {url}: {str(e)}") + + def _load_from_file(self, path: str) -> str: + with open(path, "r", encoding="utf-8") as file: + return file.read() + + def _parse_xml(self, content: str, source_ref: str) -> LoaderResult: + try: + if content.strip().startswith('<'): + root = ET.fromstring(content) + else: + root = ET.parse(source_ref).getroot() + + text_parts = [] + for text_content in root.itertext(): + if text_content and text_content.strip(): + text_parts.append(text_content.strip()) + + text = "\n".join(text_parts) + metadata = {"format": "xml", "root_tag": root.tag} + except ET.ParseError as e: + text = content + metadata = {"format": "xml", "parse_error": str(e)} + + return LoaderResult( + content=text, + source=source_ref, + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=source_ref, content=text) + ) diff --git a/src/crewai_tools/rag/misc.py b/src/crewai_tools/rag/misc.py new file mode 100644 index 000000000..5b95f804e --- /dev/null +++ b/src/crewai_tools/rag/misc.py @@ -0,0 +1,4 @@ +import hashlib + +def compute_sha256(content: str) -> str: + return hashlib.sha256(content.encode("utf-8")).hexdigest() diff --git a/src/crewai_tools/rag/source_content.py b/src/crewai_tools/rag/source_content.py new file mode 100644 index 000000000..59530c8d8 --- /dev/null +++ b/src/crewai_tools/rag/source_content.py @@ -0,0 +1,46 @@ +import os +from urllib.parse import urlparse +from typing import TYPE_CHECKING +from pathlib import Path +from functools import cached_property + +from crewai_tools.rag.misc import compute_sha256 + +if TYPE_CHECKING: + from crewai_tools.rag.data_types import DataType + + +class SourceContent: + def __init__(self, source: str | Path): + self.source = str(source) + + def is_url(self) -> bool: + if not isinstance(self.source, str): + return False + try: + parsed_url = urlparse(self.source) + return bool(parsed_url.scheme and parsed_url.netloc) + except Exception: + return False + + def path_exists(self) -> bool: + return os.path.exists(self.source) + + @cached_property + def data_type(self) -> "DataType": + from crewai_tools.rag.data_types import DataTypes + + return DataTypes.from_content(self.source) + + @cached_property + def source_ref(self) -> str: + """ + Returns the source reference for the content. + If the content is a URL or a local file, returns the source. + Otherwise, returns the hash of the content. 
+ """ + + if self.is_url() or self.path_exists(): + return self.source + + return compute_sha256(self.source) diff --git a/tests/rag/__init__.py b/tests/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rag/test_csv_loader.py b/tests/rag/test_csv_loader.py new file mode 100644 index 000000000..596cb4d58 --- /dev/null +++ b/tests/rag/test_csv_loader.py @@ -0,0 +1,130 @@ +import os +import tempfile +import pytest +from unittest.mock import patch, Mock + +from crewai_tools.rag.loaders.csv_loader import CSVLoader +from crewai_tools.rag.base_loader import LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +@pytest.fixture +def temp_csv_file(): + created_files = [] + + def _create(content: str): + f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) + f.write(content) + f.close() + created_files.append(f.name) + return f.name + + yield _create + + for path in created_files: + os.unlink(path) + + +class TestCSVLoader: + def test_load_csv_from_file(self, temp_csv_file): + path = temp_csv_file("name,age,city\nJohn,25,New York\nJane,30,Chicago") + loader = CSVLoader() + result = loader.load(SourceContent(path)) + + assert isinstance(result, LoaderResult) + assert "Headers: name | age | city" in result.content + assert "Row 1: name: John | age: 25 | city: New York" in result.content + assert "Row 2: name: Jane | age: 30 | city: Chicago" in result.content + assert result.metadata == { + "format": "csv", + "columns": ["name", "age", "city"], + "rows": 2, + } + assert result.source == path + assert result.doc_id + + def test_load_csv_with_empty_values(self, temp_csv_file): + path = temp_csv_file("name,age,city\nJohn,,New York\n,30,") + result = CSVLoader().load(SourceContent(path)) + + assert "Row 1: name: John | city: New York" in result.content + assert "Row 2: age: 30" in result.content + assert result.metadata["rows"] == 2 + + def test_load_csv_malformed(self, temp_csv_file): + path = temp_csv_file("invalid,csv\nunclosed quote \"missing") + result = CSVLoader().load(SourceContent(path)) + + assert "Headers: invalid | csv" in result.content + assert 'Row 1: invalid: unclosed quote "missing' in result.content + assert result.metadata["columns"] == ["invalid", "csv"] + + def test_load_csv_empty_file(self, temp_csv_file): + path = temp_csv_file("") + result = CSVLoader().load(SourceContent(path)) + + assert result.content == "" + assert result.metadata["rows"] == 0 + + def test_load_csv_text_input(self): + raw_csv = "col1,col2\nvalue1,value2\nvalue3,value4" + result = CSVLoader().load(SourceContent(raw_csv)) + + assert "Headers: col1 | col2" in result.content + assert "Row 1: col1: value1 | col2: value2" in result.content + assert "Row 2: col1: value3 | col2: value4" in result.content + assert result.metadata["columns"] == ["col1", "col2"] + assert result.metadata["rows"] == 2 + + def test_doc_id_is_deterministic(self, temp_csv_file): + path = temp_csv_file("name,value\ntest,123") + loader = CSVLoader() + + result1 = loader.load(SourceContent(path)) + result2 = loader.load(SourceContent(path)) + + assert result1.doc_id == result2.doc_id + + @patch("requests.get") + def test_load_csv_from_url(self, mock_get): + mock_get.return_value = Mock( + text="name,value\ntest,123", + raise_for_status=Mock(return_value=None) + ) + + result = CSVLoader().load(SourceContent("https://example.com/data.csv")) + + assert "Headers: name | value" in result.content + assert "Row 1: name: test | value: 123" in result.content + headers = 
mock_get.call_args[1]["headers"] + assert "text/csv" in headers["Accept"] + assert "crewai-tools CSVLoader" in headers["User-Agent"] + + @patch("requests.get") + def test_load_csv_with_custom_headers(self, mock_get): + mock_get.return_value = Mock( + text="data,value\ntest,456", + raise_for_status=Mock(return_value=None) + ) + headers = {"Authorization": "Bearer token", "Custom-Header": "value"} + result = CSVLoader().load(SourceContent("https://example.com/data.csv"), headers=headers) + + assert "Headers: data | value" in result.content + assert mock_get.call_args[1]["headers"] == headers + + @patch("requests.get") + def test_csv_loader_handles_network_errors(self, mock_get): + mock_get.side_effect = Exception("Network error") + loader = CSVLoader() + + with pytest.raises(ValueError, match="Error fetching CSV from URL"): + loader.load(SourceContent("https://example.com/data.csv")) + + @patch("requests.get") + def test_csv_loader_handles_http_error(self, mock_get): + mock_get.return_value = Mock() + mock_get.return_value.raise_for_status.side_effect = Exception("404 Not Found") + loader = CSVLoader() + + with pytest.raises(ValueError, match="Error fetching CSV from URL"): + loader.load(SourceContent("https://example.com/notfound.csv")) diff --git a/tests/rag/test_directory_loader.py b/tests/rag/test_directory_loader.py new file mode 100644 index 000000000..7ddb38341 --- /dev/null +++ b/tests/rag/test_directory_loader.py @@ -0,0 +1,149 @@ +import os +import tempfile +import pytest + +from crewai_tools.rag.loaders.directory_loader import DirectoryLoader +from crewai_tools.rag.base_loader import LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +@pytest.fixture +def temp_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +class TestDirectoryLoader: + def _create_file(self, directory, filename, content="test content"): + path = os.path.join(directory, filename) + with open(path, "w") as f: + f.write(content) + return path + + def test_load_non_recursive(self, temp_directory): + self._create_file(temp_directory, "file1.txt") + self._create_file(temp_directory, "file2.txt") + subdir = os.path.join(temp_directory, "subdir") + os.makedirs(subdir) + self._create_file(subdir, "file3.txt") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory), recursive=False) + + assert isinstance(result, LoaderResult) + assert "file1.txt" in result.content + assert "file2.txt" in result.content + assert "file3.txt" not in result.content + assert result.metadata["total_files"] == 2 + + def test_load_recursive(self, temp_directory): + self._create_file(temp_directory, "file1.txt") + nested = os.path.join(temp_directory, "subdir", "nested") + os.makedirs(nested) + self._create_file(os.path.join(temp_directory, "subdir"), "file2.txt") + self._create_file(nested, "file3.txt") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory), recursive=True) + + assert all(f"file{i}.txt" in result.content for i in range(1, 4)) + + def test_include_and_exclude_extensions(self, temp_directory): + self._create_file(temp_directory, "a.txt") + self._create_file(temp_directory, "b.py") + self._create_file(temp_directory, "c.md") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory), include_extensions=[".txt", ".py"]) + assert "a.txt" in result.content + assert "b.py" in result.content + assert "c.md" not in result.content + + result2 = loader.load(SourceContent(temp_directory), 
exclude_extensions=[".py", ".md"]) + assert "a.txt" in result2.content + assert "b.py" not in result2.content + assert "c.md" not in result2.content + + def test_max_files_limit(self, temp_directory): + for i in range(5): + self._create_file(temp_directory, f"file{i}.txt") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory), max_files=3) + + assert result.metadata["total_files"] == 3 + assert all(f"file{i}.txt" in result.content for i in range(3)) + + def test_hidden_files_and_dirs_excluded(self, temp_directory): + self._create_file(temp_directory, "visible.txt", "visible") + self._create_file(temp_directory, ".hidden.txt", "hidden") + + hidden_dir = os.path.join(temp_directory, ".hidden") + os.makedirs(hidden_dir) + self._create_file(hidden_dir, "inside_hidden.txt") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory), recursive=True) + + assert "visible.txt" in result.content + assert ".hidden.txt" not in result.content + assert "inside_hidden.txt" not in result.content + + def test_directory_does_not_exist(self): + loader = DirectoryLoader() + with pytest.raises(FileNotFoundError, match="Directory does not exist"): + loader.load(SourceContent("/path/does/not/exist")) + + def test_path_is_not_a_directory(self): + with tempfile.NamedTemporaryFile() as f: + loader = DirectoryLoader() + with pytest.raises(ValueError, match="Path is not a directory"): + loader.load(SourceContent(f.name)) + + def test_url_not_supported(self): + loader = DirectoryLoader() + with pytest.raises(ValueError, match="URL directory loading is not supported"): + loader.load(SourceContent("https://example.com")) + + def test_processing_error_handling(self, temp_directory): + self._create_file(temp_directory, "valid.txt") + error_file = self._create_file(temp_directory, "error.txt") + + loader = DirectoryLoader() + original_method = loader._process_single_file + + def mock(file_path): + if "error" in file_path: + raise ValueError("Mock error") + return original_method(file_path) + + loader._process_single_file = mock + result = loader.load(SourceContent(temp_directory)) + + assert "valid.txt" in result.content + assert "error.txt (ERROR)" in result.content + assert result.metadata["errors"] == 1 + assert len(result.metadata["error_details"]) == 1 + + def test_metadata_structure(self, temp_directory): + self._create_file(temp_directory, "test.txt", "Sample") + + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory)) + metadata = result.metadata + + expected_keys = { + "format", "directory_path", "total_files", "processed_files", + "errors", "file_details", "error_details" + } + + assert expected_keys.issubset(metadata) + assert all(k in metadata["file_details"][0] for k in ("path", "metadata", "source")) + + def test_empty_directory(self, temp_directory): + loader = DirectoryLoader() + result = loader.load(SourceContent(temp_directory)) + + assert result.content == "" + assert result.metadata["total_files"] == 0 + assert result.metadata["processed_files"] == 0 diff --git a/tests/rag/test_docx_loader.py b/tests/rag/test_docx_loader.py new file mode 100644 index 000000000..f95aa0662 --- /dev/null +++ b/tests/rag/test_docx_loader.py @@ -0,0 +1,135 @@ +import tempfile +import pytest +from unittest.mock import patch, Mock + +from crewai_tools.rag.loaders.docx_loader import DOCXLoader +from crewai_tools.rag.base_loader import LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class TestDOCXLoader: + 
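+ """Tests for DOCXLoader.
+
+ python-docx is patched throughout, so these tests exercise the loader's
+ control flow (file vs. URL vs. invalid source) without real .docx files.
+ """
+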
@patch('docx.Document') + def test_load_docx_from_file(self, mock_docx_class): + mock_doc = Mock() + mock_doc.paragraphs = [ + Mock(text="First paragraph"), + Mock(text="Second paragraph"), + Mock(text=" ") # Blank paragraph + ] + mock_doc.tables = [] + mock_docx_class.return_value = mock_doc + + with tempfile.NamedTemporaryFile(suffix='.docx') as f: + loader = DOCXLoader() + result = loader.load(SourceContent(f.name)) + + assert isinstance(result, LoaderResult) + assert result.content == "First paragraph\nSecond paragraph" + assert result.metadata == {"format": "docx", "paragraphs": 3, "tables": 0} + assert result.source == f.name + + @patch('docx.Document') + def test_load_docx_with_tables(self, mock_docx_class): + mock_doc = Mock() + mock_doc.paragraphs = [Mock(text="Document with table")] + mock_doc.tables = [Mock(), Mock()] + mock_docx_class.return_value = mock_doc + + with tempfile.NamedTemporaryFile(suffix='.docx') as f: + loader = DOCXLoader() + result = loader.load(SourceContent(f.name)) + + assert result.metadata["tables"] == 2 + + @patch('requests.get') + @patch('docx.Document') + @patch('tempfile.NamedTemporaryFile') + @patch('os.unlink') + def test_load_docx_from_url(self, mock_unlink, mock_tempfile, mock_docx_class, mock_get): + mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock()) + + mock_temp = Mock(name="/tmp/temp_docx_file.docx") + mock_temp.__enter__ = Mock(return_value=mock_temp) + mock_temp.__exit__ = Mock(return_value=None) + mock_tempfile.return_value = mock_temp + + mock_doc = Mock() + mock_doc.paragraphs = [Mock(text="Content from URL")] + mock_doc.tables = [] + mock_docx_class.return_value = mock_doc + + loader = DOCXLoader() + result = loader.load(SourceContent("https://example.com/test.docx")) + + assert "Content from URL" in result.content + assert result.source == "https://example.com/test.docx" + + headers = mock_get.call_args[1]['headers'] + assert "application/vnd.openxmlformats-officedocument.wordprocessingml.document" in headers['Accept'] + assert "crewai-tools DOCXLoader" in headers['User-Agent'] + + mock_temp.write.assert_called_once_with(b"fake docx content") + + @patch('requests.get') + @patch('docx.Document') + def test_load_docx_from_url_with_custom_headers(self, mock_docx_class, mock_get): + mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock()) + mock_docx_class.return_value = Mock(paragraphs=[], tables=[]) + + loader = DOCXLoader() + custom_headers = {"Authorization": "Bearer token"} + + with patch('tempfile.NamedTemporaryFile'), patch('os.unlink'): + loader.load(SourceContent("https://example.com/test.docx"), headers=custom_headers) + + assert mock_get.call_args[1]['headers'] == custom_headers + + @patch('requests.get') + def test_load_docx_url_download_error(self, mock_get): + mock_get.side_effect = Exception("Network error") + + loader = DOCXLoader() + with pytest.raises(ValueError, match="Error fetching DOCX from URL"): + loader.load(SourceContent("https://example.com/test.docx")) + + @patch('requests.get') + def test_load_docx_url_http_error(self, mock_get): + mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404 Not Found"))) + + loader = DOCXLoader() + with pytest.raises(ValueError, match="Error fetching DOCX from URL"): + loader.load(SourceContent("https://example.com/notfound.docx")) + + def test_load_docx_invalid_source(self): + loader = DOCXLoader() + with pytest.raises(ValueError, match="Source must be a valid file path or URL"): + 
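+ # A bare string that is neither an existing path nor a URL should fall
+ # through to the invalid-source branch of DOCXLoader.load.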
loader.load(SourceContent("not_a_file_or_url")) + + @patch('docx.Document') + def test_load_docx_parsing_error(self, mock_docx_class): + mock_docx_class.side_effect = Exception("Invalid DOCX file") + + with tempfile.NamedTemporaryFile(suffix='.docx') as f: + loader = DOCXLoader() + with pytest.raises(ValueError, match="Error loading DOCX file"): + loader.load(SourceContent(f.name)) + + @patch('docx.Document') + def test_load_docx_empty_document(self, mock_docx_class): + mock_docx_class.return_value = Mock(paragraphs=[], tables=[]) + + with tempfile.NamedTemporaryFile(suffix='.docx') as f: + loader = DOCXLoader() + result = loader.load(SourceContent(f.name)) + + assert result.content == "" + assert result.metadata == {"paragraphs": 0, "tables": 0, "format": "docx"} + + @patch('docx.Document') + def test_docx_doc_id_generation(self, mock_docx_class): + mock_docx_class.return_value = Mock(paragraphs=[Mock(text="Consistent content")], tables=[]) + + with tempfile.NamedTemporaryFile(suffix='.docx') as f: + loader = DOCXLoader() + source = SourceContent(f.name) + assert loader.load(source).doc_id == loader.load(source).doc_id diff --git a/tests/rag/test_json_loader.py b/tests/rag/test_json_loader.py new file mode 100644 index 000000000..b57480e16 --- /dev/null +++ b/tests/rag/test_json_loader.py @@ -0,0 +1,180 @@ +import json +import os +import tempfile +import pytest +from unittest.mock import patch, Mock + +from crewai_tools.rag.loaders.json_loader import JSONLoader +from crewai_tools.rag.base_loader import LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class TestJSONLoader: + def _create_temp_json_file(self, data) -> str: + """Helper to write JSON data to a temporary file and return its path.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(data, f) + return f.name + + def _create_temp_raw_file(self, content: str) -> str: + """Helper to write raw content to a temporary file and return its path.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write(content) + return f.name + + def _load_from_path(self, path) -> LoaderResult: + loader = JSONLoader() + return loader.load(SourceContent(path)) + + def test_load_json_dict(self): + path = self._create_temp_json_file({"name": "John", "age": 30, "items": ["a", "b", "c"]}) + try: + result = self._load_from_path(path) + assert isinstance(result, LoaderResult) + assert all(k in result.content for k in ["name", "John", "age", "30"]) + assert result.metadata == { + "format": "json", "type": "dict", "size": 3 + } + assert result.source == path + finally: + os.unlink(path) + + def test_load_json_list(self): + path = self._create_temp_json_file([ + {"id": 1, "name": "Item 1"}, + {"id": 2, "name": "Item 2"}, + ]) + try: + result = self._load_from_path(path) + assert result.metadata["type"] == "list" + assert result.metadata["size"] == 2 + assert all(item in result.content for item in ["Item 1", "Item 2"]) + finally: + os.unlink(path) + + @pytest.mark.parametrize("value, expected_type", [ + ("simple string value", "str"), + (42, "int"), + ]) + def test_load_json_primitives(self, value, expected_type): + path = self._create_temp_json_file(value) + try: + result = self._load_from_path(path) + assert result.metadata["type"] == expected_type + assert result.metadata["size"] == 1 + assert str(value) in result.content + finally: + os.unlink(path) + + def test_load_malformed_json(self): + path = self._create_temp_raw_file('{"invalid": json,}') + try: + 
result = self._load_from_path(path) + assert result.metadata["format"] == "json" + assert "parse_error" in result.metadata + assert result.content == '{"invalid": json,}' + finally: + os.unlink(path) + + def test_load_empty_file(self): + path = self._create_temp_raw_file('') + try: + result = self._load_from_path(path) + assert "parse_error" in result.metadata + assert result.content == '' + finally: + os.unlink(path) + + def test_load_text_input(self): + json_text = '{"message": "hello", "count": 5}' + loader = JSONLoader() + result = loader.load(SourceContent(json_text)) + assert all(part in result.content for part in ["message", "hello", "count", "5"]) + assert result.metadata["type"] == "dict" + assert result.metadata["size"] == 2 + + def test_load_complex_nested_json(self): + data = { + "users": [ + {"id": 1, "profile": {"name": "Alice", "settings": {"theme": "dark"}}}, + {"id": 2, "profile": {"name": "Bob", "settings": {"theme": "light"}}} + ], + "meta": {"total": 2, "version": "1.0"} + } + path = self._create_temp_json_file(data) + try: + result = self._load_from_path(path) + for value in ["Alice", "Bob", "dark", "light"]: + assert value in result.content + assert result.metadata["size"] == 2 # top-level keys + finally: + os.unlink(path) + + def test_consistent_doc_id(self): + path = self._create_temp_json_file({"test": "data"}) + try: + result1 = self._load_from_path(path) + result2 = self._load_from_path(path) + assert result1.doc_id == result2.doc_id + finally: + os.unlink(path) + + # ------------------------------ + # URL-based tests + # ------------------------------ + + @patch('requests.get') + def test_url_response_valid_json(self, mock_get): + mock_get.return_value = Mock( + text='{"key": "value", "number": 123}', + json=Mock(return_value={"key": "value", "number": 123}), + raise_for_status=Mock() + ) + + loader = JSONLoader() + result = loader.load(SourceContent("https://api.example.com/data.json")) + + assert all(val in result.content for val in ["key", "value", "number", "123"]) + headers = mock_get.call_args[1]['headers'] + assert "application/json" in headers['Accept'] + assert "crewai-tools JSONLoader" in headers['User-Agent'] + + @patch('requests.get') + def test_url_response_not_json(self, mock_get): + mock_get.return_value = Mock( + text='{"key": "value"}', + json=Mock(side_effect=ValueError("Not JSON")), + raise_for_status=Mock() + ) + + loader = JSONLoader() + result = loader.load(SourceContent("https://example.com/data.json")) + assert all(part in result.content for part in ["key", "value"]) + + @patch('requests.get') + def test_url_with_custom_headers(self, mock_get): + mock_get.return_value = Mock( + text='{"data": "test"}', + json=Mock(return_value={"data": "test"}), + raise_for_status=Mock() + ) + headers = {"Authorization": "Bearer token", "Custom-Header": "value"} + + loader = JSONLoader() + loader.load(SourceContent("https://api.example.com/data.json"), headers=headers) + + assert mock_get.call_args[1]['headers'] == headers + + @patch('requests.get') + def test_url_network_failure(self, mock_get): + mock_get.side_effect = Exception("Network error") + loader = JSONLoader() + with pytest.raises(ValueError, match="Error fetching JSON from URL"): + loader.load(SourceContent("https://api.example.com/data.json")) + + @patch('requests.get') + def test_url_http_error(self, mock_get): + mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404"))) + loader = JSONLoader() + with pytest.raises(ValueError, match="Error fetching JSON from URL"): + 
loader.load(SourceContent("https://api.example.com/404.json"))

diff --git a/tests/rag/test_mdx_loader.py b/tests/rag/test_mdx_loader.py
new file mode 100644
index 000000000..ef7944c28
--- /dev/null
+++ b/tests/rag/test_mdx_loader.py
@@ -0,0 +1,176 @@
+import os
+import tempfile
+import pytest
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.mdx_loader import MDXLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TestMDXLoader:
+
+    def _write_temp_mdx(self, content):
+        f = tempfile.NamedTemporaryFile(mode='w', suffix='.mdx', delete=False)
+        f.write(content)
+        f.close()
+        return f.name
+
+    def _load_from_file(self, content):
+        path = self._write_temp_mdx(content)
+        try:
+            loader = MDXLoader()
+            return loader.load(SourceContent(path)), path
+        finally:
+            os.unlink(path)
+
+    def test_load_basic_mdx_file(self):
+        content = """
+import Component from './Component'
+export const meta = { title: 'Test' }
+
+# Test MDX File
+
+This is a **markdown** file with JSX.
+
+<Component prop="value" />
+
+Some more content.
+
+<Component>
+  Nested content
+</Component>
+"""
+        result, path = self._load_from_file(content)
+
+        assert isinstance(result, LoaderResult)
+        assert all(tag not in result.content for tag in ["import", "export", "<Component"])
+        assert all(text in result.content for text in ["# Test MDX File", "markdown", "Some more content", "Nested content"])
+        assert result.metadata["format"] == "mdx"
+        assert result.source == path
+
+    def test_mdx_multiple_imports_exports(self):
+        content = """
+import React from 'react'
+import { useState } from 'react'
+import CustomComponent from './custom'
+
+export default function Layout() { return null }
+export const config = { test: true }
+
+# Content
+
+Regular markdown content here.
+"""
+        result, _ = self._load_from_file(content)
+        assert "# Content" in result.content
+        assert "Regular markdown content here." in result.content
+        assert "import" not in result.content and "export" not in result.content
+
+    def test_complex_jsx_cleanup(self):
+        content = """
+# MDX with Complex JSX
+
+<Callout type="info">
+  Info: This is important information.
+  • Item 1
+  • Item 2
+</Callout>
+
+Regular paragraph text.
+
+<Card>Nested content inside component</Card>
+"""
+        result, _ = self._load_from_file(content)
+        assert all(tag not in result.content for tag in ["<Callout", "</Callout>", "<Card"])
+        assert "Regular paragraph text." in result.content
+        assert "Nested content inside component" in result.content

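+    # The JSX component names in the fixtures above and below (<Component>, <Callout>,
+    # <Card>, <Widget>, <Inline />, <Standalone />) are illustrative placeholders; the
+    # behavior under test is that MDXLoader strips import/export statements and JSX
+    # tags while preserving the inner text.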
+    def test_jsx_only_content(self):
+        content = """
+<Component>
+    Only JSX content
+</Component>
+
+<Card>
+    No markdown here
+</Card>
+"""
+        result, _ = self._load_from_file(content)
+        assert all(tag not in result.content for tag in ["<Component>", "</Component>", "<Card>", "</Card>"])
+        assert "Only JSX content" in result.content
+        assert "No markdown here" in result.content
+
+    @patch('requests.get')
+    def test_load_mdx_from_url(self, mock_get):
+        mock_get.return_value = Mock(text="# MDX from URL\n\nContent here.\n\n<Widget />", raise_for_status=lambda: None)
+        loader = MDXLoader()
+        result = loader.load(SourceContent("https://example.com/content.mdx"))
+        assert "# MDX from URL" in result.content
+        assert "<Widget" not in result.content
+
+    @patch('requests.get')
+    def test_load_mdx_with_custom_headers(self, mock_get):
+        mock_get.return_value = Mock(text="# Custom headers test", raise_for_status=lambda: None)
+        loader = MDXLoader()
+        loader.load(SourceContent("https://example.com"), headers={"Authorization": "Bearer token"})
+        assert mock_get.call_args[1]['headers'] == {"Authorization": "Bearer token"}
+
+    @patch('requests.get')
+    def test_mdx_url_fetch_error(self, mock_get):
+        mock_get.side_effect = Exception("Network error")
+        with pytest.raises(ValueError, match="Error fetching MDX from URL"):
+            MDXLoader().load(SourceContent("https://example.com"))
+
+    def test_load_inline_mdx_text(self):
+        content = """# Inline MDX\n\nimport Something from 'somewhere'\n\nContent with <Inline />.\n\nexport const meta = { title: 'Test' }"""
+        loader = MDXLoader()
+        result = loader.load(SourceContent(content))
+        assert "# Inline MDX" in result.content
+        assert "Content with ." in result.content
+
+    def test_empty_result_after_cleaning(self):
+        content = """
+import Something from 'somewhere'
+export const config = {}
+<Standalone />
+"""
+        result, _ = self._load_from_file(content)
+        assert result.content.strip() == ""
+
+    def test_edge_case_parsing(self):
+        content = """
+# Title
+
+<Component>
+Multi-line
+JSX content
+</Component>
+
+import { a, b } from 'module'
+
+export { x, y }
+
+Final text.
+"""
+        result, _ = self._load_from_file(content)
+        assert "# Title" in result.content
+        assert "JSX content" in result.content
+        assert "Final text." in result.content
+        assert all(phrase not in result.content for phrase in ["import {", "export {", "<Component"])

diff --git a/tests/rag/test_text_loaders.py b/tests/rag/test_text_loaders.py
new file mode 100644
index 000000000..e72738778
--- /dev/null
+++ b/tests/rag/test_text_loaders.py
@@ -0,0 +1,160 @@
+import hashlib
+import os
+import tempfile
+import pytest
+
+from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+def write_temp_file(content, suffix=".txt", encoding="utf-8"):
+    with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding=encoding) as f:
+        f.write(content)
+    return f.name
+
+
+def cleanup_temp_file(path):
+    try:
+        os.unlink(path)
+    except FileNotFoundError:
+        pass
+
+
+class TestTextFileLoader:
+    def test_basic_text_file(self):
+        content = "This is test content\nWith multiple lines\nAnd more text"
+        path = write_temp_file(content)
+        try:
+            result = TextFileLoader().load(SourceContent(path))
+            assert isinstance(result, LoaderResult)
+            assert result.content == content
+            assert result.source == path
+            assert result.doc_id
+            assert result.metadata in (None, {})
+        finally:
+            cleanup_temp_file(path)
+
+    def test_empty_file(self):
+        path = write_temp_file("")
+        try:
+            result = TextFileLoader().load(SourceContent(path))
+            assert result.content == ""
+        finally:
+            cleanup_temp_file(path)
+
+    def test_unicode_content(self):
+        content = "Hello 世界 🌍 émojis 🎉 åäö"
+        path = write_temp_file(content)
+        try:
+            result = TextFileLoader().load(SourceContent(path))
+            assert content in result.content
+        finally:
+            cleanup_temp_file(path)
+
+    def test_large_file(self):
+        content = "\n".join(f"Line {i}" for i in range(100))
+        path = write_temp_file(content)
+        try:
+            result = TextFileLoader().load(SourceContent(path))
+            assert "Line 0" in result.content
+            assert "Line 99" in result.content
+            assert result.content.count("\n") == 99
+        finally:
+            cleanup_temp_file(path)
+
+    def test_missing_file(self):
+        with pytest.raises(FileNotFoundError):
+            TextFileLoader().load(SourceContent("/nonexistent/path.txt"))
+
+    def test_permission_denied(self):
+        path = write_temp_file("Some content")
+        os.chmod(path, 0o000)
+        try:
+            with pytest.raises(PermissionError):
+                TextFileLoader().load(SourceContent(path))
+        finally:
+            os.chmod(path, 0o644)
+            cleanup_temp_file(path)
+
+    def test_doc_id_consistency(self):
+        content = "Consistent content"
+        path = write_temp_file(content)
+        try:
+            loader = TextFileLoader()
+            result1 = loader.load(SourceContent(path))
+            result2 = loader.load(SourceContent(path))
+            expected_id = hashlib.sha256((path + content).encode("utf-8")).hexdigest()
+            assert result1.doc_id == result2.doc_id == expected_id
+        finally:
+            cleanup_temp_file(path)
+
+    def test_various_extensions(self):
+        content = "Same content"
+        for ext in [".txt", ".md", ".log", ".json"]:
+            path = write_temp_file(content, suffix=ext)
+            try:
+                result = TextFileLoader().load(SourceContent(path))
+                assert result.content == content
+            finally:
+                cleanup_temp_file(path)
+
+
+class 
TestTextLoader: + def test_basic_text(self): + content = "Raw text" + result = TextLoader().load(SourceContent(content)) + expected_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + assert result.content == content + assert result.source == expected_hash + assert result.doc_id == expected_hash + + def test_multiline_text(self): + content = "Line 1\nLine 2\nLine 3" + result = TextLoader().load(SourceContent(content)) + assert "Line 2" in result.content + + def test_empty_text(self): + result = TextLoader().load(SourceContent("")) + assert result.content == "" + assert result.source == hashlib.sha256("".encode("utf-8")).hexdigest() + + def test_unicode_text(self): + content = "世界 🌍 émojis 🎉 åäö" + result = TextLoader().load(SourceContent(content)) + assert content in result.content + + def test_special_characters(self): + content = "!@#$$%^&*()_+-=~`{}[]\\|;:'\",.<>/?" + result = TextLoader().load(SourceContent(content)) + assert result.content == content + + def test_doc_id_uniqueness(self): + result1 = TextLoader().load(SourceContent("A")) + result2 = TextLoader().load(SourceContent("B")) + assert result1.doc_id != result2.doc_id + + def test_whitespace_text(self): + content = " \n\t " + result = TextLoader().load(SourceContent(content)) + assert result.content == content + + def test_long_text(self): + content = "A" * 10000 + result = TextLoader().load(SourceContent(content)) + assert len(result.content) == 10000 + + +class TestTextLoadersIntegration: + def test_consistency_between_loaders(self): + content = "Consistent content" + text_result = TextLoader().load(SourceContent(content)) + file_path = write_temp_file(content) + try: + file_result = TextFileLoader().load(SourceContent(file_path)) + + assert text_result.content == file_result.content + assert text_result.source != file_result.source + assert text_result.doc_id != file_result.doc_id + finally: + cleanup_temp_file(file_path) diff --git a/tests/rag/test_webpage_loader.py b/tests/rag/test_webpage_loader.py new file mode 100644 index 000000000..9e02f410b --- /dev/null +++ b/tests/rag/test_webpage_loader.py @@ -0,0 +1,137 @@ +import pytest +from unittest.mock import patch, Mock +from crewai_tools.rag.loaders.webpage_loader import WebPageLoader +from crewai_tools.rag.base_loader import LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class TestWebPageLoader: + def setup_mock_response(self, text, status_code=200, content_type="text/html"): + response = Mock() + response.text = text + response.apparent_encoding = "utf-8" + response.status_code = status_code + response.headers = {"content-type": content_type} + return response + + def setup_mock_soup(self, text, title=None, script_style_elements=None): + soup = Mock() + soup.get_text.return_value = text + soup.title = Mock(string=title) if title is not None else None + soup.return_value = script_style_elements or [] + return soup + + @patch('requests.get') + @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') + def test_load_basic_webpage(self, mock_bs, mock_get): + mock_get.return_value = self.setup_mock_response("Test Page
Test content")
+        mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com"))
+
+        assert isinstance(result, LoaderResult)
+        assert result.content == "Test content"
+        assert result.metadata["title"] == "Test Page"
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
+        # Minimal fixture HTML; parsing is fully mocked below, so only the
+        # decompose() calls on the script/style mocks are actually verified.
+        html = """<html>
+        <head><title>Page with Scripts</title>
+        <script>var ignored = true;</script>
+        <style>body { color: red; }</style>
+        </head>
+        <body>Visible content</body>
+        </html>"""
+        mock_get.return_value = self.setup_mock_response(html)
+        scripts = [Mock(), Mock()]
+        styles = [Mock()]
+        for el in scripts + styles:
+            el.decompose = Mock()
+        mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com/with-scripts"))
+
+        assert "Visible content" in result.content
+        for el in scripts + styles:
+            el.decompose.assert_called_once()
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
+        mock_get.return_value = self.setup_mock_response("Messy text")
+        mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None)
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com/messy-text"))
+        assert result.content is not None
+        assert result.metadata["title"] == ""
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_empty_or_missing_title(self, mock_bs, mock_get):
+        for title in [None, ""]:
+            mock_get.return_value = self.setup_mock_response("Content")
+            mock_bs.return_value = self.setup_mock_soup("Content", title=title)
+
+            loader = WebPageLoader()
+            result = loader.load(SourceContent("https://example.com"))
+            assert result.metadata["title"] == ""
+
+    @patch('requests.get')
+    def test_custom_and_default_headers(self, mock_get):
+        mock_get.return_value = self.setup_mock_response("Test")
+        custom_headers = {"User-Agent": "Bot", "Authorization": "Bearer xyz", "Accept": "text/html"}
+
+        with patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') as mock_bs:
+            mock_bs.return_value = self.setup_mock_soup("Test")
+            WebPageLoader().load(SourceContent("https://example.com"), headers=custom_headers)
+
+        assert mock_get.call_args[1]['headers'] == custom_headers
+
+    @patch('requests.get')
+    def test_error_handling(self, mock_get):
+        for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
+            mock_get.side_effect = error
+            with pytest.raises(ValueError, match="Error loading webpage"):
+                WebPageLoader().load(SourceContent("https://example.com"))
+
+    @patch('requests.get')
+    def test_timeout_and_http_error(self, mock_get):
+        import requests
+        mock_get.side_effect = requests.Timeout("Timeout")
+        with pytest.raises(ValueError):
+            WebPageLoader().load(SourceContent("https://example.com"))
+
+        mock_response = Mock()
+        mock_response.raise_for_status.side_effect = requests.HTTPError("404")
+        mock_get.side_effect = None
+        mock_get.return_value = mock_response
+        with pytest.raises(ValueError):
+            WebPageLoader().load(SourceContent("https://example.com/404"))
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_doc_id_consistency(self, mock_bs, mock_get):
+        mock_get.return_value = self.setup_mock_response("Doc")
+        mock_bs.return_value = self.setup_mock_soup("Doc")
+
+        loader = WebPageLoader()
+        result1 = loader.load(SourceContent("https://example.com"))
+        result2 = loader.load(SourceContent("https://example.com"))
+
+        assert result1.doc_id == result2.doc_id
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_status_code_and_content_type(self, mock_bs, mock_get):
+        for status in [200, 201, 301]:
+            mock_get.return_value = self.setup_mock_response(f"Status {status}", status_code=status)
+            mock_bs.return_value = self.setup_mock_soup(f"Status {status}")
+            result = WebPageLoader().load(SourceContent(f"https://example.com/{status}"))
+            assert result.metadata["status_code"] == status
+
+        for ctype in ["text/html", "text/plain", "application/xhtml+xml"]:
+            mock_get.return_value = self.setup_mock_response("Content", content_type=ctype)
+            mock_bs.return_value = self.setup_mock_soup("Content")
+            result = WebPageLoader().load(SourceContent("https://example.com"))
+            assert result.metadata["content_type"] == ctype