diff --git a/src/crewai_tools/adapters/rag_adapter.py b/src/crewai_tools/adapters/rag_adapter.py
new file mode 100644
index 000000000..78011328c
--- /dev/null
+++ b/src/crewai_tools/adapters/rag_adapter.py
@@ -0,0 +1,41 @@
+from typing import Any, Optional
+
+from crewai_tools.rag.core import RAG
+from crewai_tools.tools.rag.rag_tool import Adapter
+
+
+class RAGAdapter(Adapter):
+ def __init__(
+ self,
+ collection_name: str = "crewai_knowledge_base",
+ persist_directory: Optional[str] = None,
+ embedding_model: str = "text-embedding-3-small",
+ top_k: int = 5,
+ embedding_api_key: Optional[str] = None,
+ **embedding_kwargs
+ ):
+ super().__init__()
+
+        # Only forward the API key when it was explicitly provided
+        embedding_config = {**embedding_kwargs}
+        if embedding_api_key:
+            embedding_config["api_key"] = embedding_api_key
+
+ self._adapter = RAG(
+ collection_name=collection_name,
+ persist_directory=persist_directory,
+ embedding_model=embedding_model,
+ top_k=top_k,
+ embedding_config=embedding_config
+ )
+
+ def query(self, question: str) -> str:
+ return self._adapter.query(question)
+
+ def add(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ self._adapter.add(*args, **kwargs)
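+
+# Illustrative usage (a sketch; assumes an embedding provider key such as
+# OPENAI_API_KEY is available in the environment, and a hypothetical file path):
+#
+#     adapter = RAGAdapter(collection_name="my_kb", top_k=3)
+#     adapter.add("/path/to/document.txt")
+#     print(adapter.query("What does the document say about pricing?"))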
diff --git a/src/crewai_tools/rag/__init__.py b/src/crewai_tools/rag/__init__.py
new file mode 100644
index 000000000..8d08b2907
--- /dev/null
+++ b/src/crewai_tools/rag/__init__.py
@@ -0,0 +1,8 @@
+from crewai_tools.rag.core import RAG, EmbeddingService
+from crewai_tools.rag.data_types import DataType
+
+__all__ = [
+ "RAG",
+ "EmbeddingService",
+ "DataType",
+]
diff --git a/src/crewai_tools/rag/base_loader.py b/src/crewai_tools/rag/base_loader.py
new file mode 100644
index 000000000..e38d6f8c1
--- /dev/null
+++ b/src/crewai_tools/rag/base_loader.py
@@ -0,0 +1,37 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field
+
+from crewai_tools.rag.misc import compute_sha256
+from crewai_tools.rag.source_content import SourceContent
+
+
+class LoaderResult(BaseModel):
+ content: str = Field(description="The text content of the source")
+ source: str = Field(description="The source of the content", default="unknown")
+ metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict)
+ doc_id: str = Field(description="The id of the document")
+
+
+class BaseLoader(ABC):
+ def __init__(self, config: Optional[Dict[str, Any]] = None):
+ self.config = config or {}
+
+ @abstractmethod
+ def load(self, content: SourceContent, **kwargs) -> LoaderResult:
+ ...
+
+ def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str:
+ """
+ Generate a unique document id based on the source reference and content.
+ If the source reference is not provided, the content is used as the source reference.
+ If the content is not provided, the source reference is used as the content.
+ If both are provided, the source reference is used as the content.
+
+ Both are optional because the TEXT content type does not have a source reference. In this case, the content is used as the source reference.
+ """
+
+ source_ref = source_ref or ""
+ content = content or ""
+
+ return compute_sha256(source_ref + content)
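+
+# A sketch of the id scheme, not extra behavior: for a raw TEXT source the
+# content alone drives the hash, so loading the same text twice yields the
+# same id.
+#
+#     loader.generate_doc_id(source_ref="a.txt", content="hello")
+#     # == compute_sha256("a.txt" + "hello")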
diff --git a/src/crewai_tools/rag/chunkers/__init__.py b/src/crewai_tools/rag/chunkers/__init__.py
new file mode 100644
index 000000000..f48483391
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/__init__.py
@@ -0,0 +1,15 @@
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
+from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker
+from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker
+from crewai_tools.rag.chunkers.web_chunker import WebsiteChunker
+
+__all__ = [
+ "BaseChunker",
+ "DefaultChunker",
+ "TextChunker",
+ "DocxChunker",
+ "MdxChunker",
+ "CsvChunker",
+ "JsonChunker",
+ "XmlChunker",
+]
diff --git a/src/crewai_tools/rag/chunkers/base_chunker.py b/src/crewai_tools/rag/chunkers/base_chunker.py
new file mode 100644
index 000000000..deafbfc7a
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/base_chunker.py
@@ -0,0 +1,167 @@
+from typing import List, Optional
+import re
+
+class RecursiveCharacterTextSplitter:
+ """
+ A text splitter that recursively splits text based on a hierarchy of separators.
+ """
+
+ def __init__(
+ self,
+ chunk_size: int = 4000,
+ chunk_overlap: int = 200,
+ separators: Optional[List[str]] = None,
+ keep_separator: bool = True,
+ ):
+ """
+ Initialize the RecursiveCharacterTextSplitter.
+
+ Args:
+ chunk_size: Maximum size of each chunk
+ chunk_overlap: Number of characters to overlap between chunks
+ separators: List of separators to use for splitting (in order of preference)
+ keep_separator: Whether to keep the separator in the split text
+ """
+ if chunk_overlap >= chunk_size:
+ raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})")
+
+ self._chunk_size = chunk_size
+ self._chunk_overlap = chunk_overlap
+ self._keep_separator = keep_separator
+
+ self._separators = separators or [
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+
+ def split_text(self, text: str) -> List[str]:
+ return self._split_text(text, self._separators)
+
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
+ separator = separators[-1]
+ new_separators = []
+
+ for i, sep in enumerate(separators):
+ if sep == "":
+ separator = sep
+ break
+ if re.search(re.escape(sep), text):
+ separator = sep
+ new_separators = separators[i + 1:]
+ break
+
+ splits = self._split_text_with_separator(text, separator)
+
+ good_splits = []
+
+ for split in splits:
+ if len(split) < self._chunk_size:
+ good_splits.append(split)
+ else:
+ if new_separators:
+ other_info = self._split_text(split, new_separators)
+ good_splits.extend(other_info)
+ else:
+ good_splits.extend(self._split_by_characters(split))
+
+ return self._merge_splits(good_splits, separator)
+
+ def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
+ if separator == "":
+ return list(text)
+
+ if self._keep_separator and separator in text:
+ parts = text.split(separator)
+ splits = []
+
+            for i, part in enumerate(parts):
+                if i == 0:
+                    splits.append(part)
+                elif part:
+                    splits.append(separator + part)
+                elif i < len(parts) - 1 and splits:
+                    # Empty middle part: keep the separator attached to the previous split
+                    splits[-1] += separator
+
+ return [s for s in splits if s]
+ else:
+ return text.split(separator)
+
+ def _split_by_characters(self, text: str) -> List[str]:
+ chunks = []
+ for i in range(0, len(text), self._chunk_size):
+ chunks.append(text[i:i + self._chunk_size])
+ return chunks
+
+ def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
+ """Merge splits into chunks with proper overlap."""
+ docs = []
+ current_doc = []
+ total = 0
+
+ for split in splits:
+ split_len = len(split)
+
+ if total + split_len > self._chunk_size and current_doc:
+ if separator == "":
+ doc = "".join(current_doc)
+ else:
+ doc = separator.join(current_doc)
+
+ if doc:
+ docs.append(doc)
+
+ # Handle overlap by keeping some of the previous content
+ while total > self._chunk_overlap and len(current_doc) > 1:
+ removed = current_doc.pop(0)
+ total -= len(removed)
+ if separator != "":
+ total -= len(separator)
+
+ current_doc.append(split)
+ total += split_len
+ if separator != "" and len(current_doc) > 1:
+ total += len(separator)
+
+ if current_doc:
+ if separator == "":
+ doc = "".join(current_doc)
+ else:
+ doc = separator.join(current_doc)
+
+ if doc:
+ docs.append(doc)
+
+ return docs
+
+class BaseChunker:
+ def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ """
+ Initialize the Chunker
+
+ Args:
+ chunk_size: Maximum size of each chunk
+ chunk_overlap: Number of characters to overlap between chunks
+ separators: List of separators to use for splitting
+ keep_separator: Whether to keep separators in the chunks
+ """
+
+ self._splitter = RecursiveCharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ separators=separators,
+ keep_separator=keep_separator,
+ )
+
+
+ def chunk(self, text: str) -> List[str]:
+ if not text or not text.strip():
+ return []
+
+ return self._splitter.split_text(text)
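+
+# Illustrative sketch of the splitter's behavior with a small window, to make
+# the separator hierarchy and the overlap visible (sizes are arbitrary):
+#
+#     chunker = BaseChunker(chunk_size=40, chunk_overlap=10)
+#     for piece in chunker.chunk("First paragraph.\n\nSecond, longer paragraph."):
+#         print(repr(piece))  # paragraph breaks are preferred split points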
diff --git a/src/crewai_tools/rag/chunkers/default_chunker.py b/src/crewai_tools/rag/chunkers/default_chunker.py
new file mode 100644
index 000000000..0d0ec6935
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/default_chunker.py
@@ -0,0 +1,6 @@
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from typing import List, Optional
+
+class DefaultChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
diff --git a/src/crewai_tools/rag/chunkers/structured_chunker.py b/src/crewai_tools/rag/chunkers/structured_chunker.py
new file mode 100644
index 000000000..483f92588
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/structured_chunker.py
@@ -0,0 +1,49 @@
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from typing import List, Optional
+
+
+class CsvChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\nRow ", # Row boundaries (from CSVLoader format)
+ "\n", # Line breaks
+ " | ", # Column separators
+ ", ", # Comma separators
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
+
+
+class JsonChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n\n", # Object/array boundaries
+ "\n", # Line breaks
+ "},", # Object endings
+ "],", # Array endings
+ ", ", # Property separators
+ ": ", # Key-value separators
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
+
+
+class XmlChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n\n", # Element boundaries
+ "\n", # Line breaks
+ ">", # Tag endings
+ ". ", # Sentence endings (for text content)
+ "! ", # Exclamation endings
+ "? ", # Question endings
+ ", ", # Comma separators
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
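+
+# Sketch: CsvChunker is tuned to the "Row N: ..." text that CSVLoader emits,
+# so row boundaries are preferred over mid-row splits (sizes are arbitrary):
+#
+#     chunker = CsvChunker(chunk_size=60, chunk_overlap=0)
+#     text = "Headers: name | age\nRow 1: name: John | age: 25\nRow 2: name: Jane | age: 30"
+#     chunks = chunker.chunk(text)  # splits at "\nRow " before falling back to "\n"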
diff --git a/src/crewai_tools/rag/chunkers/text_chunker.py b/src/crewai_tools/rag/chunkers/text_chunker.py
new file mode 100644
index 000000000..2e76df8ab
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/text_chunker.py
@@ -0,0 +1,59 @@
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from typing import List, Optional
+
+
+class TextChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n\n\n", # Multiple line breaks (sections)
+ "\n\n", # Paragraph breaks
+ "\n", # Line breaks
+ ". ", # Sentence endings
+ "! ", # Exclamation endings
+ "? ", # Question endings
+ "; ", # Semicolon breaks
+ ", ", # Comma breaks
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
+
+
+class DocxChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n\n\n", # Multiple line breaks (major sections)
+ "\n\n", # Paragraph breaks
+ "\n", # Line breaks
+ ". ", # Sentence endings
+ "! ", # Exclamation endings
+ "? ", # Question endings
+ "; ", # Semicolon breaks
+ ", ", # Comma breaks
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
+
+
+class MdxChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n## ", # H2 headers (major sections)
+ "\n### ", # H3 headers (subsections)
+ "\n#### ", # H4 headers (sub-subsections)
+ "\n\n", # Paragraph breaks
+ "\n```", # Code block boundaries
+ "\n", # Line breaks
+ ". ", # Sentence endings
+ "! ", # Exclamation endings
+ "? ", # Question endings
+ "; ", # Semicolon breaks
+ ", ", # Comma breaks
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
diff --git a/src/crewai_tools/rag/chunkers/web_chunker.py b/src/crewai_tools/rag/chunkers/web_chunker.py
new file mode 100644
index 000000000..2712a6c69
--- /dev/null
+++ b/src/crewai_tools/rag/chunkers/web_chunker.py
@@ -0,0 +1,20 @@
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from typing import List, Optional
+
+
+class WebsiteChunker(BaseChunker):
+ def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
+ if separators is None:
+ separators = [
+ "\n\n\n", # Major section breaks
+ "\n\n", # Paragraph breaks
+ "\n", # Line breaks
+ ". ", # Sentence endings
+ "! ", # Exclamation endings
+ "? ", # Question endings
+ "; ", # Semicolon breaks
+ ", ", # Comma breaks
+ " ", # Word breaks
+ "", # Character level
+ ]
+ super().__init__(chunk_size, chunk_overlap, separators, keep_separator)
diff --git a/src/crewai_tools/rag/core.py b/src/crewai_tools/rag/core.py
new file mode 100644
index 000000000..0aa4b666c
--- /dev/null
+++ b/src/crewai_tools/rag/core.py
@@ -0,0 +1,232 @@
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+from uuid import uuid4
+
+import chromadb
+import litellm
+from pydantic import BaseModel, Field, PrivateAttr
+
+from crewai_tools.tools.rag.rag_tool import Adapter
+from crewai_tools.rag.data_types import DataType
+from crewai_tools.rag.base_loader import BaseLoader
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from crewai_tools.rag.source_content import SourceContent
+from crewai_tools.rag.misc import compute_sha256
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingService:
+ def __init__(self, model: str = "text-embedding-3-small", **kwargs):
+ self.model = model
+ self.kwargs = kwargs
+
+ def embed_text(self, text: str) -> List[float]:
+ try:
+ response = litellm.embedding(
+ model=self.model,
+ input=[text],
+ **self.kwargs
+ )
+ return response.data[0]['embedding']
+ except Exception as e:
+ logger.error(f"Error generating embedding: {e}")
+ raise
+
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
+ if not texts:
+ return []
+
+ try:
+ response = litellm.embedding(
+ model=self.model,
+ input=texts,
+ **self.kwargs
+ )
+ return [data['embedding'] for data in response.data]
+ except Exception as e:
+ logger.error(f"Error generating batch embeddings: {e}")
+ raise
+
+
+class Document(BaseModel):
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ content: str
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+ data_type: DataType = DataType.TEXT
+ source: Optional[str] = None
+
+
+class RAG(Adapter):
+ collection_name: str = "crewai_knowledge_base"
+ persist_directory: Optional[str] = None
+ embedding_model: str = "text-embedding-3-large"
+ summarize: bool = False
+ top_k: int = 5
+ embedding_config: Dict[str, Any] = Field(default_factory=dict)
+
+ _client: Any = PrivateAttr()
+ _collection: Any = PrivateAttr()
+ _embedding_service: EmbeddingService = PrivateAttr()
+
+ def model_post_init(self, __context: Any) -> None:
+ try:
+ if self.persist_directory:
+ self._client = chromadb.PersistentClient(path=self.persist_directory)
+ else:
+ self._client = chromadb.Client()
+
+ self._collection = self._client.get_or_create_collection(
+ name=self.collection_name,
+ metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"}
+ )
+
+ self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
+ except Exception as e:
+ logger.error(f"Failed to initialize ChromaDB: {e}")
+ raise
+
+ super().model_post_init(__context)
+
+ def add(
+ self,
+ content: str | Path,
+ data_type: Optional[Union[str, DataType]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ loader: Optional[BaseLoader] = None,
+ chunker: Optional[BaseChunker] = None,
+ **kwargs: Any
+ ) -> None:
+ source_content = SourceContent(content)
+
+ data_type = self._get_data_type(data_type=data_type, content=source_content)
+
+ if not loader:
+ loader = data_type.get_loader()
+
+ if not chunker:
+ chunker = data_type.get_chunker()
+
+ loader_result = loader.load(source_content)
+ doc_id = loader_result.doc_id
+
+        existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
+        existing_metadatas = existing_doc.get("metadatas") or []
+        existing_doc_id = existing_metadatas[0].get("doc_id") if existing_metadatas else None
+
+ if existing_doc_id == doc_id:
+ logger.warning(f"Document with source {loader_result.source} already exists")
+ return
+
+        # A document with the same source ref exists but its content has changed: delete the stale entries
+        if existing_doc_id:
+            logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
+            self._collection.delete(where={"doc_id": existing_doc_id})
+
+ documents = []
+
+ chunks = chunker.chunk(loader_result.content)
+ for i, chunk in enumerate(chunks):
+ doc_metadata = (metadata or {}).copy()
+ doc_metadata['chunk_index'] = i
+            documents.append(Document(
+                # Scope chunk ids to the parent doc and position so identical chunks cannot collide
+                id=compute_sha256(f"{doc_id}:{i}:{chunk}"),
+ content=chunk,
+ metadata=doc_metadata,
+ data_type=data_type,
+ source=loader_result.source
+ ))
+
+ if not documents:
+ logger.warning("No documents to add")
+ return
+
+ contents = [doc.content for doc in documents]
+ try:
+ embeddings = self._embedding_service.embed_batch(contents)
+ except Exception as e:
+ logger.error(f"Failed to generate embeddings: {e}")
+ return
+
+ ids = [doc.id for doc in documents]
+ metadatas = []
+
+ for doc in documents:
+ doc_metadata = doc.metadata.copy()
+ doc_metadata.update({
+ "data_type": doc.data_type.value,
+ "source": doc.source,
+ "doc_id": doc_id
+ })
+ metadatas.append(doc_metadata)
+
+ try:
+ self._collection.add(
+ ids=ids,
+ embeddings=embeddings,
+ documents=contents,
+ metadatas=metadatas,
+ )
+ logger.info(f"Added {len(documents)} documents to knowledge base")
+ except Exception as e:
+ logger.error(f"Failed to add documents to ChromaDB: {e}")
+
+ def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
+ try:
+ question_embedding = self._embedding_service.embed_text(question)
+
+ results = self._collection.query(
+ query_embeddings=[question_embedding],
+ n_results=self.top_k,
+ where=where,
+ include=["documents", "metadatas", "distances"]
+ )
+
+ if not results or not results.get("documents") or not results["documents"][0]:
+ return "No relevant content found."
+
+ documents = results["documents"][0]
+ metadatas = results.get("metadatas", [None])[0] or []
+ distances = results.get("distances", [None])[0] or []
+
+ # Return sources with relevance scores
+ formatted_results = []
+ for i, doc in enumerate(documents):
+ metadata = metadatas[i] if i < len(metadatas) else {}
+ distance = distances[i] if i < len(distances) else 1.0
+ source = metadata.get("source", "unknown") if metadata else "unknown"
+ score = 1 - distance if distance is not None else 0 # Convert distance to similarity
+ formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")
+
+ return "\n\n".join(formatted_results)
+ except Exception as e:
+ logger.error(f"Query failed: {e}")
+ return f"Error querying knowledge base: {e}"
+
+ def delete_collection(self) -> None:
+ try:
+ self._client.delete_collection(self.collection_name)
+ logger.info(f"Deleted collection: {self.collection_name}")
+ except Exception as e:
+ logger.error(f"Failed to delete collection: {e}")
+
+ def get_collection_info(self) -> Dict[str, Any]:
+ try:
+ count = self._collection.count()
+ return {
+ "name": self.collection_name,
+ "count": count,
+ "embedding_model": self.embedding_model
+ }
+ except Exception as e:
+ logger.error(f"Failed to get collection info: {e}")
+ return {"error": str(e)}
+
+    def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
+        if data_type is not None:
+            try:
+                return DataType(data_type)
+            except ValueError:
+                logger.warning(f"Unknown data type {data_type!r}; inferring the type from the content instead")
+
+        return content.data_type
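+
+# Illustrative end-to-end sketch (assumes a Chroma-compatible environment and an
+# embedding provider key, e.g. OPENAI_API_KEY; the path and URL are hypothetical):
+#
+#     rag = RAG(collection_name="demo", persist_directory="/tmp/demo_db", top_k=3)
+#     rag.add("/path/to/notes.txt")           # data type inferred from the file
+#     rag.add("https://example.com/article")  # inferred as DataType.WEBSITE
+#     print(rag.query("What do the notes say about pricing?"))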
diff --git a/src/crewai_tools/rag/data_types.py b/src/crewai_tools/rag/data_types.py
new file mode 100644
index 000000000..d2d265cce
--- /dev/null
+++ b/src/crewai_tools/rag/data_types.py
@@ -0,0 +1,137 @@
+from enum import Enum
+from pathlib import Path
+from urllib.parse import urlparse
+import os
+from crewai_tools.rag.chunkers.base_chunker import BaseChunker
+from crewai_tools.rag.base_loader import BaseLoader
+
+class DataType(str, Enum):
+ PDF_FILE = "pdf_file"
+ TEXT_FILE = "text_file"
+ CSV = "csv"
+ JSON = "json"
+ XML = "xml"
+ DOCX = "docx"
+ MDX = "mdx"
+
+ # Database types
+ MYSQL = "mysql"
+ POSTGRES = "postgres"
+
+ # Repository types
+ GITHUB = "github"
+ DIRECTORY = "directory"
+
+ # Web types
+ WEBSITE = "website"
+ DOCS_SITE = "docs_site"
+
+ # Raw types
+ TEXT = "text"
+
+
+ def get_chunker(self) -> BaseChunker:
+ from importlib import import_module
+
+ chunkers = {
+ DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
+ DataType.TEXT: ("text_chunker", "TextChunker"),
+ DataType.DOCX: ("text_chunker", "DocxChunker"),
+ DataType.MDX: ("text_chunker", "MdxChunker"),
+
+ # Structured formats
+ DataType.CSV: ("structured_chunker", "CsvChunker"),
+ DataType.JSON: ("structured_chunker", "JsonChunker"),
+ DataType.XML: ("structured_chunker", "XmlChunker"),
+
+ DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
+ }
+
+ module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
+ module_path = f"crewai_tools.rag.chunkers.{module_name}"
+
+ try:
+ module = import_module(module_path)
+ return getattr(module, class_name)()
+ except Exception as e:
+ raise ValueError(f"Error loading chunker for {self}: {e}")
+
+ def get_loader(self) -> BaseLoader:
+ from importlib import import_module
+
+ loaders = {
+ DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
+ DataType.TEXT: ("text_loader", "TextLoader"),
+ DataType.XML: ("xml_loader", "XMLLoader"),
+ DataType.WEBSITE: ("webpage_loader", "WebPageLoader"),
+ DataType.MDX: ("mdx_loader", "MDXLoader"),
+ DataType.JSON: ("json_loader", "JSONLoader"),
+ DataType.DOCX: ("docx_loader", "DOCXLoader"),
+ DataType.CSV: ("csv_loader", "CSVLoader"),
+ DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
+ }
+
+ module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
+ module_path = f"crewai_tools.rag.loaders.{module_name}"
+ try:
+ module = import_module(module_path)
+ return getattr(module, class_name)()
+ except Exception as e:
+ raise ValueError(f"Error loading loader for {self}: {e}")
+
+class DataTypes:
+ @staticmethod
+ def from_content(content: str | Path | None = None) -> DataType:
+ if content is None:
+ return DataType.TEXT
+
+ if isinstance(content, Path):
+ content = str(content)
+
+ is_url = False
+ if isinstance(content, str):
+ try:
+ url = urlparse(content)
+ is_url = (url.scheme and url.netloc) or url.scheme == "file"
+ except Exception:
+ pass
+
+ def get_file_type(path: str) -> DataType | None:
+ mapping = {
+ ".pdf": DataType.PDF_FILE,
+ ".csv": DataType.CSV,
+ ".mdx": DataType.MDX,
+ ".md": DataType.MDX,
+ ".docx": DataType.DOCX,
+ ".json": DataType.JSON,
+ ".xml": DataType.XML,
+ ".txt": DataType.TEXT_FILE,
+ }
+ for ext, dtype in mapping.items():
+ if path.endswith(ext):
+ return dtype
+ return None
+
+ if is_url:
+ dtype = get_file_type(url.path)
+ if dtype:
+ return dtype
+
+ if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
+ return DataType.DOCS_SITE
+ if "github.com" in url.netloc:
+ return DataType.GITHUB
+
+ return DataType.WEBSITE
+
+        if os.path.isdir(content):
+            return DataType.DIRECTORY
+
+        if os.path.isfile(content):
+            dtype = get_file_type(content)
+            if dtype:
+                return dtype
+            return DataType.TEXT_FILE
+
+ return DataType.TEXT
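+
+# Sketch of the inference rules above (paths and URLs are hypothetical):
+#
+#     DataTypes.from_content("report.csv")                   # -> DataType.CSV (when the file exists)
+#     DataTypes.from_content("https://docs.example.com/a")   # -> DataType.DOCS_SITE
+#     DataTypes.from_content("https://github.com/org/repo")  # -> DataType.GITHUB
+#     DataTypes.from_content("just some raw text")           # -> DataType.TEXT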
diff --git a/src/crewai_tools/rag/loaders/__init__.py b/src/crewai_tools/rag/loaders/__init__.py
new file mode 100644
index 000000000..503651468
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/__init__.py
@@ -0,0 +1,20 @@
+from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
+from crewai_tools.rag.loaders.xml_loader import XMLLoader
+from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
+from crewai_tools.rag.loaders.mdx_loader import MDXLoader
+from crewai_tools.rag.loaders.json_loader import JSONLoader
+from crewai_tools.rag.loaders.docx_loader import DOCXLoader
+from crewai_tools.rag.loaders.csv_loader import CSVLoader
+from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
+
+__all__ = [
+ "TextFileLoader",
+ "TextLoader",
+ "XMLLoader",
+ "WebPageLoader",
+ "MDXLoader",
+ "JSONLoader",
+ "DOCXLoader",
+ "CSVLoader",
+ "DirectoryLoader",
+]
diff --git a/src/crewai_tools/rag/loaders/csv_loader.py b/src/crewai_tools/rag/loaders/csv_loader.py
new file mode 100644
index 000000000..e389123a7
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/csv_loader.py
@@ -0,0 +1,72 @@
+import csv
+from io import StringIO
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class CSVLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ source_ref = source_content.source_ref
+
+ content_str = source_content.source
+ if source_content.is_url():
+ content_str = self._load_from_url(content_str, kwargs)
+ elif source_content.path_exists():
+ content_str = self._load_from_file(content_str)
+
+ return self._parse_csv(content_str, source_ref)
+
+
+ def _load_from_url(self, url: str, kwargs: dict) -> str:
+ import requests
+
+ headers = kwargs.get("headers", {
+ "Accept": "text/csv, application/csv, text/plain",
+ "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)"
+ })
+
+ try:
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+ return response.text
+ except Exception as e:
+ raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}")
+
+ def _load_from_file(self, path: str) -> str:
+ with open(path, "r", encoding="utf-8") as file:
+ return file.read()
+
+ def _parse_csv(self, content: str, source_ref: str) -> LoaderResult:
+ try:
+ csv_reader = csv.DictReader(StringIO(content))
+
+ text_parts = []
+ headers = csv_reader.fieldnames
+
+ if headers:
+ text_parts.append("Headers: " + " | ".join(headers))
+ text_parts.append("-" * 50)
+
+ for row_num, row in enumerate(csv_reader, 1):
+ row_text = " | ".join([f"{k}: {v}" for k, v in row.items() if v])
+ text_parts.append(f"Row {row_num}: {row_text}")
+
+ text = "\n".join(text_parts)
+
+ metadata = {
+ "format": "csv",
+ "columns": headers,
+ "rows": len(text_parts) - 2 if headers else 0
+ }
+
+ except Exception as e:
+ text = content
+ metadata = {"format": "csv", "parse_error": str(e)}
+
+ return LoaderResult(
+ content=text,
+ source=source_ref,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
+ )
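+
+# Sketch of the text shape this loader produces, which CsvChunker's separators
+# are tuned to ("people.csv" is a hypothetical local file):
+#
+#     result = CSVLoader().load(SourceContent("people.csv"))
+#     # result.content:
+#     #   Headers: name | age
+#     #   --------------------------------------------------
+#     #   Row 1: name: John | age: 25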
diff --git a/src/crewai_tools/rag/loaders/directory_loader.py b/src/crewai_tools/rag/loaders/directory_loader.py
new file mode 100644
index 000000000..7bc5f298b
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/directory_loader.py
@@ -0,0 +1,142 @@
+import os
+from pathlib import Path
+from typing import List
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class DirectoryLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ """
+ Load and process all files from a directory recursively.
+
+ Args:
+ source: Directory path or URL to a directory listing
+ **kwargs: Additional options:
+ - recursive: bool (default True) - Whether to search recursively
+ - include_extensions: list - Only include files with these extensions
+ - exclude_extensions: list - Exclude files with these extensions
+ - max_files: int - Maximum number of files to process
+ """
+ source_ref = source_content.source_ref
+
+ if source_content.is_url():
+ raise ValueError("URL directory loading is not supported. Please provide a local directory path.")
+
+ if not os.path.exists(source_ref):
+ raise FileNotFoundError(f"Directory does not exist: {source_ref}")
+
+ if not os.path.isdir(source_ref):
+ raise ValueError(f"Path is not a directory: {source_ref}")
+
+ return self._process_directory(source_ref, kwargs)
+
+ def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
+ recursive = kwargs.get("recursive", True)
+ include_extensions = kwargs.get("include_extensions", None)
+ exclude_extensions = kwargs.get("exclude_extensions", None)
+ max_files = kwargs.get("max_files", None)
+
+ files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions)
+
+ if max_files and len(files) > max_files:
+ files = files[:max_files]
+
+ all_contents = []
+ processed_files = []
+ errors = []
+
+ for file_path in files:
+ try:
+ result = self._process_single_file(file_path)
+ if result:
+ all_contents.append(f"=== File: {file_path} ===\n{result.content}")
+ processed_files.append({
+ "path": file_path,
+ "metadata": result.metadata,
+ "source": result.source
+ })
+ except Exception as e:
+ error_msg = f"Error processing {file_path}: {str(e)}"
+ errors.append(error_msg)
+ all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")
+
+ combined_content = "\n\n".join(all_contents)
+
+ metadata = {
+ "format": "directory",
+ "directory_path": dir_path,
+ "total_files": len(files),
+ "processed_files": len(processed_files),
+ "errors": len(errors),
+ "file_details": processed_files,
+ "error_details": errors
+ }
+
+ return LoaderResult(
+ content=combined_content,
+ source=dir_path,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content)
+ )
+
+ def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]:
+ """Find all files in directory matching criteria."""
+ files = []
+
+ if recursive:
+ for root, dirs, filenames in os.walk(dir_path):
+ dirs[:] = [d for d in dirs if not d.startswith('.')]
+
+ for filename in filenames:
+ if self._should_include_file(filename, include_ext, exclude_ext):
+ files.append(os.path.join(root, filename))
+ else:
+ try:
+ for item in os.listdir(dir_path):
+ item_path = os.path.join(dir_path, item)
+ if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext):
+ files.append(item_path)
+ except PermissionError:
+ pass
+
+ return sorted(files)
+
+    def _should_include_file(self, filename: str, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> bool:
+ """Determine if a file should be included based on criteria."""
+ if filename.startswith('.'):
+ return False
+
+ _, ext = os.path.splitext(filename.lower())
+
+ if include_ext:
+ if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]:
+ return False
+
+ if exclude_ext:
+ if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]:
+ return False
+
+ return True
+
+ def _process_single_file(self, file_path: str) -> LoaderResult:
+ from crewai_tools.rag.data_types import DataTypes
+
+ data_type = DataTypes.from_content(Path(file_path))
+
+ loader = data_type.get_loader()
+
+ result = loader.load(SourceContent(file_path))
+
+ if result.metadata is None:
+ result.metadata = {}
+
+ result.metadata.update({
+ "file_path": file_path,
+ "file_size": os.path.getsize(file_path),
+ "data_type": str(data_type),
+ "loader_type": loader.__class__.__name__
+ })
+
+ return result
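+
+# Illustrative call showing the supported kwargs (the path is hypothetical):
+#
+#     result = DirectoryLoader().load(
+#         SourceContent("/path/to/project"),
+#         recursive=True,
+#         include_extensions=[".md", ".txt"],
+#         max_files=50,
+#     )
+#     print(result.metadata["processed_files"], "files loaded")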
diff --git a/src/crewai_tools/rag/loaders/docx_loader.py b/src/crewai_tools/rag/loaders/docx_loader.py
new file mode 100644
index 000000000..2f5df23af
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/docx_loader.py
@@ -0,0 +1,72 @@
+import os
+import tempfile
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class DOCXLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ try:
+ from docx import Document as DocxDocument
+ except ImportError:
+            raise ImportError("python-docx is required for DOCX loading. Install it with 'uv pip install python-docx' or 'pip install crewai-tools[rag]'.")
+
+ source_ref = source_content.source_ref
+
+ if source_content.is_url():
+ temp_file = self._download_from_url(source_ref, kwargs)
+ try:
+ return self._load_from_file(temp_file, source_ref, DocxDocument)
+ finally:
+ os.unlink(temp_file)
+ elif source_content.path_exists():
+ return self._load_from_file(source_ref, source_ref, DocxDocument)
+ else:
+ raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}")
+
+ def _download_from_url(self, url: str, kwargs: dict) -> str:
+ import requests
+
+ headers = kwargs.get("headers", {
+ "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)"
+ })
+
+ try:
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+
+ # Create temporary file to save the DOCX content
+ with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file:
+ temp_file.write(response.content)
+ return temp_file.name
+ except Exception as e:
+ raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}")
+
+ def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult:
+ try:
+ doc = DocxDocument(file_path)
+
+ text_parts = []
+ for paragraph in doc.paragraphs:
+ if paragraph.text.strip():
+ text_parts.append(paragraph.text)
+
+ content = "\n".join(text_parts)
+
+ metadata = {
+ "format": "docx",
+ "paragraphs": len(doc.paragraphs),
+ "tables": len(doc.tables)
+ }
+
+ return LoaderResult(
+ content=content,
+ source=source_ref,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
+ )
+
+ except Exception as e:
+ raise ValueError(f"Error loading DOCX file: {str(e)}")
diff --git a/src/crewai_tools/rag/loaders/json_loader.py b/src/crewai_tools/rag/loaders/json_loader.py
new file mode 100644
index 000000000..6efab393a
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/json_loader.py
@@ -0,0 +1,69 @@
+import json
+
+from crewai_tools.rag.source_content import SourceContent
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+
+
+class JSONLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ source_ref = source_content.source_ref
+ content = source_content.source
+
+ if source_content.is_url():
+ content = self._load_from_url(source_ref, kwargs)
+ elif source_content.path_exists():
+ content = self._load_from_file(source_ref)
+
+ return self._parse_json(content, source_ref)
+
+ def _load_from_url(self, url: str, kwargs: dict) -> str:
+ import requests
+
+ headers = kwargs.get("headers", {
+ "Accept": "application/json",
+ "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)"
+ })
+
+ try:
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+            if self._is_json_response(response):
+                return json.dumps(response.json(), indent=2)
+            return response.text
+ except Exception as e:
+ raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}")
+
+ def _is_json_response(self, response) -> bool:
+ try:
+ response.json()
+ return True
+ except ValueError:
+ return False
+
+ def _load_from_file(self, path: str) -> str:
+ with open(path, "r", encoding="utf-8") as file:
+ return file.read()
+
+ def _parse_json(self, content: str, source_ref: str) -> LoaderResult:
+ try:
+ data = json.loads(content)
+ if isinstance(data, dict):
+ text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
+ elif isinstance(data, list):
+ text = "\n".join(json.dumps(item, indent=0) for item in data)
+ else:
+ text = json.dumps(data, indent=0)
+
+ metadata = {
+ "format": "json",
+ "type": type(data).__name__,
+ "size": len(data) if isinstance(data, (list, dict)) else 1
+ }
+ except json.JSONDecodeError as e:
+ text = content
+ metadata = {"format": "json", "parse_error": str(e)}
+
+ return LoaderResult(
+ content=text,
+ source=source_ref,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
+ )
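+
+# Sketch of the flattening performed above: each top-level key becomes one line.
+#
+#     result = JSONLoader().load(SourceContent('{"name": "Ada", "age": 36}'))
+#     # result.content == 'name: "Ada"\nage: 36'
+#     # result.metadata == {"format": "json", "type": "dict", "size": 2}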
diff --git a/src/crewai_tools/rag/loaders/mdx_loader.py b/src/crewai_tools/rag/loaders/mdx_loader.py
new file mode 100644
index 000000000..6da9dc896
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/mdx_loader.py
@@ -0,0 +1,59 @@
+import re
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+class MDXLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ source_ref = source_content.source_ref
+ content = source_content.source
+
+ if source_content.is_url():
+ content = self._load_from_url(source_ref, kwargs)
+ elif source_content.path_exists():
+ content = self._load_from_file(source_ref)
+
+ return self._parse_mdx(content, source_ref)
+
+ def _load_from_url(self, url: str, kwargs: dict) -> str:
+ import requests
+
+ headers = kwargs.get("headers", {
+ "Accept": "text/markdown, text/x-markdown, text/plain",
+ "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)"
+ })
+
+ try:
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+ return response.text
+ except Exception as e:
+ raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}")
+
+ def _load_from_file(self, path: str) -> str:
+ with open(path, "r", encoding="utf-8") as file:
+ return file.read()
+
+ def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
+ cleaned_content = content
+
+ # Remove import statements
+ cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE)
+
+ # Remove export statements
+ cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE)
+
+ # Remove JSX tags (simple approach)
+ cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content)
+
+ # Clean up extra whitespace
+ cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content)
+ cleaned_content = cleaned_content.strip()
+
+ metadata = {"format": "mdx"}
+ return LoaderResult(
+ content=cleaned_content,
+ source=source_ref,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content)
+ )
diff --git a/src/crewai_tools/rag/loaders/text_loader.py b/src/crewai_tools/rag/loaders/text_loader.py
new file mode 100644
index 000000000..a97cf29f4
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/text_loader.py
@@ -0,0 +1,28 @@
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TextFileLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ source_ref = source_content.source_ref
+ if not source_content.path_exists():
+ raise FileNotFoundError(f"The following file does not exist: {source_content.source}")
+
+ with open(source_content.source, "r", encoding="utf-8") as file:
+ content = file.read()
+
+ return LoaderResult(
+ content=content,
+ source=source_ref,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
+ )
+
+
+class TextLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ return LoaderResult(
+ content=source_content.source,
+ source=source_content.source_ref,
+ doc_id=self.generate_doc_id(content=source_content.source)
+ )
diff --git a/src/crewai_tools/rag/loaders/webpage_loader.py b/src/crewai_tools/rag/loaders/webpage_loader.py
new file mode 100644
index 000000000..4fcb1e0c4
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/webpage_loader.py
@@ -0,0 +1,47 @@
+import re
+import requests
+from bs4 import BeautifulSoup
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+class WebPageLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ url = source_content.source
+ headers = kwargs.get("headers", {
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
+ "Accept-Language": "en-US,en;q=0.9",
+ })
+
+        try:
+            response = requests.get(url, timeout=15, headers=headers)
+            response.raise_for_status()
+            response.encoding = response.apparent_encoding
+
+ soup = BeautifulSoup(response.text, "html.parser")
+
+ for script in soup(["script", "style"]):
+ script.decompose()
+
+ text = soup.get_text(" ")
+ text = re.sub("[ \t]+", " ", text)
+ text = re.sub("\\s+\n\\s+", "\n", text)
+ text = text.strip()
+
+ title = soup.title.string.strip() if soup.title and soup.title.string else ""
+ metadata = {
+ "url": url,
+ "title": title,
+ "status_code": response.status_code,
+ "content_type": response.headers.get("content-type", "")
+ }
+
+ return LoaderResult(
+ content=text,
+ source=url,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=url, content=text)
+ )
+
+ except Exception as e:
+ raise ValueError(f"Error loading webpage {url}: {str(e)}")
diff --git a/src/crewai_tools/rag/loaders/xml_loader.py b/src/crewai_tools/rag/loaders/xml_loader.py
new file mode 100644
index 000000000..ffafdb9d9
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/xml_loader.py
@@ -0,0 +1,61 @@
+import os
+import xml.etree.ElementTree as ET
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+class XMLLoader(BaseLoader):
+ def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
+ source_ref = source_content.source_ref
+ content = source_content.source
+
+ if source_content.is_url():
+ content = self._load_from_url(source_ref, kwargs)
+ elif os.path.exists(source_ref):
+ content = self._load_from_file(source_ref)
+
+ return self._parse_xml(content, source_ref)
+
+ def _load_from_url(self, url: str, kwargs: dict) -> str:
+ import requests
+
+ headers = kwargs.get("headers", {
+ "Accept": "application/xml, text/xml, text/plain",
+ "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)"
+ })
+
+ try:
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+ return response.text
+ except Exception as e:
+ raise ValueError(f"Error fetching XML from URL {url}: {str(e)}")
+
+ def _load_from_file(self, path: str) -> str:
+ with open(path, "r", encoding="utf-8") as file:
+ return file.read()
+
+ def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
+ try:
+ if content.strip().startswith('<'):
+ root = ET.fromstring(content)
+ else:
+ root = ET.parse(source_ref).getroot()
+
+ text_parts = []
+ for text_content in root.itertext():
+ if text_content and text_content.strip():
+ text_parts.append(text_content.strip())
+
+ text = "\n".join(text_parts)
+ metadata = {"format": "xml", "root_tag": root.tag}
+ except ET.ParseError as e:
+ text = content
+ metadata = {"format": "xml", "parse_error": str(e)}
+
+ return LoaderResult(
+ content=text,
+ source=source_ref,
+ metadata=metadata,
+ doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
+ )
diff --git a/src/crewai_tools/rag/misc.py b/src/crewai_tools/rag/misc.py
new file mode 100644
index 000000000..5b95f804e
--- /dev/null
+++ b/src/crewai_tools/rag/misc.py
@@ -0,0 +1,4 @@
+import hashlib
+
+def compute_sha256(content: str) -> str:
+ return hashlib.sha256(content.encode("utf-8")).hexdigest()
diff --git a/src/crewai_tools/rag/source_content.py b/src/crewai_tools/rag/source_content.py
new file mode 100644
index 000000000..59530c8d8
--- /dev/null
+++ b/src/crewai_tools/rag/source_content.py
@@ -0,0 +1,46 @@
+import os
+from urllib.parse import urlparse
+from typing import TYPE_CHECKING
+from pathlib import Path
+from functools import cached_property
+
+from crewai_tools.rag.misc import compute_sha256
+
+if TYPE_CHECKING:
+ from crewai_tools.rag.data_types import DataType
+
+
+class SourceContent:
+ def __init__(self, source: str | Path):
+ self.source = str(source)
+
+ def is_url(self) -> bool:
+ if not isinstance(self.source, str):
+ return False
+ try:
+ parsed_url = urlparse(self.source)
+ return bool(parsed_url.scheme and parsed_url.netloc)
+ except Exception:
+ return False
+
+ def path_exists(self) -> bool:
+ return os.path.exists(self.source)
+
+ @cached_property
+ def data_type(self) -> "DataType":
+ from crewai_tools.rag.data_types import DataTypes
+
+ return DataTypes.from_content(self.source)
+
+ @cached_property
+ def source_ref(self) -> str:
+ """"
+ Returns the source reference for the content.
+ If the content is a URL or a local file, returns the source.
+ Otherwise, returns the hash of the content.
+ """
+
+ if self.is_url() or self.path_exists():
+ return self.source
+
+ return compute_sha256(self.source)
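+
+# Sketch of how source_ref behaves (the URL is hypothetical):
+#
+#     SourceContent("https://example.com/a.txt").source_ref  # the URL itself
+#     SourceContent("raw text, not a path").source_ref       # sha256 of the text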
diff --git a/tests/rag/__init__.py b/tests/rag/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/rag/test_csv_loader.py b/tests/rag/test_csv_loader.py
new file mode 100644
index 000000000..596cb4d58
--- /dev/null
+++ b/tests/rag/test_csv_loader.py
@@ -0,0 +1,130 @@
+import os
+import tempfile
+import pytest
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.csv_loader import CSVLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+@pytest.fixture
+def temp_csv_file():
+ created_files = []
+
+ def _create(content: str):
+ f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False)
+ f.write(content)
+ f.close()
+ created_files.append(f.name)
+ return f.name
+
+ yield _create
+
+ for path in created_files:
+ os.unlink(path)
+
+
+class TestCSVLoader:
+ def test_load_csv_from_file(self, temp_csv_file):
+ path = temp_csv_file("name,age,city\nJohn,25,New York\nJane,30,Chicago")
+ loader = CSVLoader()
+ result = loader.load(SourceContent(path))
+
+ assert isinstance(result, LoaderResult)
+ assert "Headers: name | age | city" in result.content
+ assert "Row 1: name: John | age: 25 | city: New York" in result.content
+ assert "Row 2: name: Jane | age: 30 | city: Chicago" in result.content
+ assert result.metadata == {
+ "format": "csv",
+ "columns": ["name", "age", "city"],
+ "rows": 2,
+ }
+ assert result.source == path
+ assert result.doc_id
+
+ def test_load_csv_with_empty_values(self, temp_csv_file):
+ path = temp_csv_file("name,age,city\nJohn,,New York\n,30,")
+ result = CSVLoader().load(SourceContent(path))
+
+ assert "Row 1: name: John | city: New York" in result.content
+ assert "Row 2: age: 30" in result.content
+ assert result.metadata["rows"] == 2
+
+ def test_load_csv_malformed(self, temp_csv_file):
+ path = temp_csv_file("invalid,csv\nunclosed quote \"missing")
+ result = CSVLoader().load(SourceContent(path))
+
+ assert "Headers: invalid | csv" in result.content
+ assert 'Row 1: invalid: unclosed quote "missing' in result.content
+ assert result.metadata["columns"] == ["invalid", "csv"]
+
+ def test_load_csv_empty_file(self, temp_csv_file):
+ path = temp_csv_file("")
+ result = CSVLoader().load(SourceContent(path))
+
+ assert result.content == ""
+ assert result.metadata["rows"] == 0
+
+ def test_load_csv_text_input(self):
+ raw_csv = "col1,col2\nvalue1,value2\nvalue3,value4"
+ result = CSVLoader().load(SourceContent(raw_csv))
+
+ assert "Headers: col1 | col2" in result.content
+ assert "Row 1: col1: value1 | col2: value2" in result.content
+ assert "Row 2: col1: value3 | col2: value4" in result.content
+ assert result.metadata["columns"] == ["col1", "col2"]
+ assert result.metadata["rows"] == 2
+
+ def test_doc_id_is_deterministic(self, temp_csv_file):
+ path = temp_csv_file("name,value\ntest,123")
+ loader = CSVLoader()
+
+ result1 = loader.load(SourceContent(path))
+ result2 = loader.load(SourceContent(path))
+
+ assert result1.doc_id == result2.doc_id
+
+ @patch("requests.get")
+ def test_load_csv_from_url(self, mock_get):
+ mock_get.return_value = Mock(
+ text="name,value\ntest,123",
+ raise_for_status=Mock(return_value=None)
+ )
+
+ result = CSVLoader().load(SourceContent("https://example.com/data.csv"))
+
+ assert "Headers: name | value" in result.content
+ assert "Row 1: name: test | value: 123" in result.content
+ headers = mock_get.call_args[1]["headers"]
+ assert "text/csv" in headers["Accept"]
+ assert "crewai-tools CSVLoader" in headers["User-Agent"]
+
+ @patch("requests.get")
+ def test_load_csv_with_custom_headers(self, mock_get):
+ mock_get.return_value = Mock(
+ text="data,value\ntest,456",
+ raise_for_status=Mock(return_value=None)
+ )
+ headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
+ result = CSVLoader().load(SourceContent("https://example.com/data.csv"), headers=headers)
+
+ assert "Headers: data | value" in result.content
+ assert mock_get.call_args[1]["headers"] == headers
+
+ @patch("requests.get")
+ def test_csv_loader_handles_network_errors(self, mock_get):
+ mock_get.side_effect = Exception("Network error")
+ loader = CSVLoader()
+
+ with pytest.raises(ValueError, match="Error fetching CSV from URL"):
+ loader.load(SourceContent("https://example.com/data.csv"))
+
+ @patch("requests.get")
+ def test_csv_loader_handles_http_error(self, mock_get):
+ mock_get.return_value = Mock()
+ mock_get.return_value.raise_for_status.side_effect = Exception("404 Not Found")
+ loader = CSVLoader()
+
+ with pytest.raises(ValueError, match="Error fetching CSV from URL"):
+ loader.load(SourceContent("https://example.com/notfound.csv"))
diff --git a/tests/rag/test_directory_loader.py b/tests/rag/test_directory_loader.py
new file mode 100644
index 000000000..7ddb38341
--- /dev/null
+++ b/tests/rag/test_directory_loader.py
@@ -0,0 +1,149 @@
+import os
+import tempfile
+import pytest
+
+from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+@pytest.fixture
+def temp_directory():
+ with tempfile.TemporaryDirectory() as temp_dir:
+ yield temp_dir
+
+
+class TestDirectoryLoader:
+ def _create_file(self, directory, filename, content="test content"):
+ path = os.path.join(directory, filename)
+ with open(path, "w") as f:
+ f.write(content)
+ return path
+
+ def test_load_non_recursive(self, temp_directory):
+ self._create_file(temp_directory, "file1.txt")
+ self._create_file(temp_directory, "file2.txt")
+ subdir = os.path.join(temp_directory, "subdir")
+ os.makedirs(subdir)
+ self._create_file(subdir, "file3.txt")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory), recursive=False)
+
+ assert isinstance(result, LoaderResult)
+ assert "file1.txt" in result.content
+ assert "file2.txt" in result.content
+ assert "file3.txt" not in result.content
+ assert result.metadata["total_files"] == 2
+
+ def test_load_recursive(self, temp_directory):
+ self._create_file(temp_directory, "file1.txt")
+ nested = os.path.join(temp_directory, "subdir", "nested")
+ os.makedirs(nested)
+ self._create_file(os.path.join(temp_directory, "subdir"), "file2.txt")
+ self._create_file(nested, "file3.txt")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory), recursive=True)
+
+ assert all(f"file{i}.txt" in result.content for i in range(1, 4))
+
+ def test_include_and_exclude_extensions(self, temp_directory):
+ self._create_file(temp_directory, "a.txt")
+ self._create_file(temp_directory, "b.py")
+ self._create_file(temp_directory, "c.md")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory), include_extensions=[".txt", ".py"])
+ assert "a.txt" in result.content
+ assert "b.py" in result.content
+ assert "c.md" not in result.content
+
+ result2 = loader.load(SourceContent(temp_directory), exclude_extensions=[".py", ".md"])
+ assert "a.txt" in result2.content
+ assert "b.py" not in result2.content
+ assert "c.md" not in result2.content
+
+ def test_max_files_limit(self, temp_directory):
+ for i in range(5):
+ self._create_file(temp_directory, f"file{i}.txt")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory), max_files=3)
+
+ assert result.metadata["total_files"] == 3
+ assert all(f"file{i}.txt" in result.content for i in range(3))
+
+ def test_hidden_files_and_dirs_excluded(self, temp_directory):
+ self._create_file(temp_directory, "visible.txt", "visible")
+ self._create_file(temp_directory, ".hidden.txt", "hidden")
+
+ hidden_dir = os.path.join(temp_directory, ".hidden")
+ os.makedirs(hidden_dir)
+ self._create_file(hidden_dir, "inside_hidden.txt")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory), recursive=True)
+
+ assert "visible.txt" in result.content
+ assert ".hidden.txt" not in result.content
+ assert "inside_hidden.txt" not in result.content
+
+ def test_directory_does_not_exist(self):
+ loader = DirectoryLoader()
+ with pytest.raises(FileNotFoundError, match="Directory does not exist"):
+ loader.load(SourceContent("/path/does/not/exist"))
+
+ def test_path_is_not_a_directory(self):
+ with tempfile.NamedTemporaryFile() as f:
+ loader = DirectoryLoader()
+ with pytest.raises(ValueError, match="Path is not a directory"):
+ loader.load(SourceContent(f.name))
+
+ def test_url_not_supported(self):
+ loader = DirectoryLoader()
+ with pytest.raises(ValueError, match="URL directory loading is not supported"):
+ loader.load(SourceContent("https://example.com"))
+
+ def test_processing_error_handling(self, temp_directory):
+ self._create_file(temp_directory, "valid.txt")
+        self._create_file(temp_directory, "error.txt")
+
+ loader = DirectoryLoader()
+ original_method = loader._process_single_file
+
+ def mock(file_path):
+ if "error" in file_path:
+ raise ValueError("Mock error")
+ return original_method(file_path)
+
+ loader._process_single_file = mock
+ result = loader.load(SourceContent(temp_directory))
+
+ assert "valid.txt" in result.content
+ assert "error.txt (ERROR)" in result.content
+ assert result.metadata["errors"] == 1
+ assert len(result.metadata["error_details"]) == 1
+
+ def test_metadata_structure(self, temp_directory):
+ self._create_file(temp_directory, "test.txt", "Sample")
+
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory))
+ metadata = result.metadata
+
+ expected_keys = {
+ "format", "directory_path", "total_files", "processed_files",
+ "errors", "file_details", "error_details"
+ }
+
+ assert expected_keys.issubset(metadata)
+ assert all(k in metadata["file_details"][0] for k in ("path", "metadata", "source"))
+
+ def test_empty_directory(self, temp_directory):
+ loader = DirectoryLoader()
+ result = loader.load(SourceContent(temp_directory))
+
+ assert result.content == ""
+ assert result.metadata["total_files"] == 0
+ assert result.metadata["processed_files"] == 0
diff --git a/tests/rag/test_docx_loader.py b/tests/rag/test_docx_loader.py
new file mode 100644
index 000000000..f95aa0662
--- /dev/null
+++ b/tests/rag/test_docx_loader.py
@@ -0,0 +1,135 @@
+import tempfile
+import pytest
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.docx_loader import DOCXLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TestDOCXLoader:
+ @patch('docx.Document')
+ def test_load_docx_from_file(self, mock_docx_class):
+ mock_doc = Mock()
+ mock_doc.paragraphs = [
+ Mock(text="First paragraph"),
+ Mock(text="Second paragraph"),
+ Mock(text=" ") # Blank paragraph
+ ]
+ mock_doc.tables = []
+ mock_docx_class.return_value = mock_doc
+
+ with tempfile.NamedTemporaryFile(suffix='.docx') as f:
+ loader = DOCXLoader()
+ result = loader.load(SourceContent(f.name))
+
+ assert isinstance(result, LoaderResult)
+ assert result.content == "First paragraph\nSecond paragraph"
+ assert result.metadata == {"format": "docx", "paragraphs": 3, "tables": 0}
+ assert result.source == f.name
+
+ @patch('docx.Document')
+ def test_load_docx_with_tables(self, mock_docx_class):
+ mock_doc = Mock()
+ mock_doc.paragraphs = [Mock(text="Document with table")]
+ mock_doc.tables = [Mock(), Mock()]
+ mock_docx_class.return_value = mock_doc
+
+ with tempfile.NamedTemporaryFile(suffix='.docx') as f:
+ loader = DOCXLoader()
+ result = loader.load(SourceContent(f.name))
+
+ assert result.metadata["tables"] == 2
+
+ @patch('requests.get')
+ @patch('docx.Document')
+ @patch('tempfile.NamedTemporaryFile')
+ @patch('os.unlink')
+ def test_load_docx_from_url(self, mock_unlink, mock_tempfile, mock_docx_class, mock_get):
+ mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
+
+        # Mock(name=...) sets the mock's repr name, not the .name attribute, so set it explicitly.
+        mock_temp = Mock()
+        mock_temp.name = "/tmp/temp_docx_file.docx"
+ mock_temp.__enter__ = Mock(return_value=mock_temp)
+ mock_temp.__exit__ = Mock(return_value=None)
+ mock_tempfile.return_value = mock_temp
+
+ mock_doc = Mock()
+ mock_doc.paragraphs = [Mock(text="Content from URL")]
+ mock_doc.tables = []
+ mock_docx_class.return_value = mock_doc
+
+ loader = DOCXLoader()
+ result = loader.load(SourceContent("https://example.com/test.docx"))
+
+ assert "Content from URL" in result.content
+ assert result.source == "https://example.com/test.docx"
+
+ headers = mock_get.call_args[1]['headers']
+ assert "application/vnd.openxmlformats-officedocument.wordprocessingml.document" in headers['Accept']
+ assert "crewai-tools DOCXLoader" in headers['User-Agent']
+
+ mock_temp.write.assert_called_once_with(b"fake docx content")
+
+ @patch('requests.get')
+ @patch('docx.Document')
+ def test_load_docx_from_url_with_custom_headers(self, mock_docx_class, mock_get):
+ mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
+ mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
+
+ loader = DOCXLoader()
+ custom_headers = {"Authorization": "Bearer token"}
+
+ with patch('tempfile.NamedTemporaryFile'), patch('os.unlink'):
+ loader.load(SourceContent("https://example.com/test.docx"), headers=custom_headers)
+
+ assert mock_get.call_args[1]['headers'] == custom_headers
+
+ @patch('requests.get')
+ def test_load_docx_url_download_error(self, mock_get):
+ mock_get.side_effect = Exception("Network error")
+
+ loader = DOCXLoader()
+ with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
+ loader.load(SourceContent("https://example.com/test.docx"))
+
+ @patch('requests.get')
+ def test_load_docx_url_http_error(self, mock_get):
+ mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404 Not Found")))
+
+ loader = DOCXLoader()
+ with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
+ loader.load(SourceContent("https://example.com/notfound.docx"))
+
+ def test_load_docx_invalid_source(self):
+ loader = DOCXLoader()
+ with pytest.raises(ValueError, match="Source must be a valid file path or URL"):
+ loader.load(SourceContent("not_a_file_or_url"))
+
+ @patch('docx.Document')
+ def test_load_docx_parsing_error(self, mock_docx_class):
+ mock_docx_class.side_effect = Exception("Invalid DOCX file")
+
+ with tempfile.NamedTemporaryFile(suffix='.docx') as f:
+ loader = DOCXLoader()
+ with pytest.raises(ValueError, match="Error loading DOCX file"):
+ loader.load(SourceContent(f.name))
+
+ @patch('docx.Document')
+ def test_load_docx_empty_document(self, mock_docx_class):
+ mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
+
+ with tempfile.NamedTemporaryFile(suffix='.docx') as f:
+ loader = DOCXLoader()
+ result = loader.load(SourceContent(f.name))
+
+ assert result.content == ""
+ assert result.metadata == {"paragraphs": 0, "tables": 0, "format": "docx"}
+
+ @patch('docx.Document')
+ def test_docx_doc_id_generation(self, mock_docx_class):
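+        # doc_id is a hash of source and content, so loading the same file twice yields the same id.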
+ mock_docx_class.return_value = Mock(paragraphs=[Mock(text="Consistent content")], tables=[])
+
+ with tempfile.NamedTemporaryFile(suffix='.docx') as f:
+ loader = DOCXLoader()
+ source = SourceContent(f.name)
+ assert loader.load(source).doc_id == loader.load(source).doc_id
diff --git a/tests/rag/test_json_loader.py b/tests/rag/test_json_loader.py
new file mode 100644
index 000000000..b57480e16
--- /dev/null
+++ b/tests/rag/test_json_loader.py
@@ -0,0 +1,180 @@
+import json
+import os
+import tempfile
+import pytest
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.json_loader import JSONLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TestJSONLoader:
+ def _create_temp_json_file(self, data) -> str:
+ """Helper to write JSON data to a temporary file and return its path."""
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ json.dump(data, f)
+ return f.name
+
+ def _create_temp_raw_file(self, content: str) -> str:
+ """Helper to write raw content to a temporary file and return its path."""
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ f.write(content)
+ return f.name
+
+ def _load_from_path(self, path) -> LoaderResult:
+ loader = JSONLoader()
+ return loader.load(SourceContent(path))
+
+ def test_load_json_dict(self):
+ path = self._create_temp_json_file({"name": "John", "age": 30, "items": ["a", "b", "c"]})
+ try:
+ result = self._load_from_path(path)
+ assert isinstance(result, LoaderResult)
+ assert all(k in result.content for k in ["name", "John", "age", "30"])
+ assert result.metadata == {
+ "format": "json", "type": "dict", "size": 3
+ }
+ assert result.source == path
+ finally:
+ os.unlink(path)
+
+ def test_load_json_list(self):
+ path = self._create_temp_json_file([
+ {"id": 1, "name": "Item 1"},
+ {"id": 2, "name": "Item 2"},
+ ])
+ try:
+ result = self._load_from_path(path)
+ assert result.metadata["type"] == "list"
+ assert result.metadata["size"] == 2
+ assert all(item in result.content for item in ["Item 1", "Item 2"])
+ finally:
+ os.unlink(path)
+
+ @pytest.mark.parametrize("value, expected_type", [
+ ("simple string value", "str"),
+ (42, "int"),
+ ])
+ def test_load_json_primitives(self, value, expected_type):
+ path = self._create_temp_json_file(value)
+ try:
+ result = self._load_from_path(path)
+ assert result.metadata["type"] == expected_type
+ assert result.metadata["size"] == 1
+ assert str(value) in result.content
+ finally:
+ os.unlink(path)
+
+ def test_load_malformed_json(self):
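+        # A parse failure should fall back to the raw text and record the error in metadata.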
+ path = self._create_temp_raw_file('{"invalid": json,}')
+ try:
+ result = self._load_from_path(path)
+ assert result.metadata["format"] == "json"
+ assert "parse_error" in result.metadata
+ assert result.content == '{"invalid": json,}'
+ finally:
+ os.unlink(path)
+
+ def test_load_empty_file(self):
+ path = self._create_temp_raw_file('')
+ try:
+ result = self._load_from_path(path)
+ assert "parse_error" in result.metadata
+ assert result.content == ''
+ finally:
+ os.unlink(path)
+
+ def test_load_text_input(self):
+ json_text = '{"message": "hello", "count": 5}'
+ loader = JSONLoader()
+ result = loader.load(SourceContent(json_text))
+ assert all(part in result.content for part in ["message", "hello", "count", "5"])
+ assert result.metadata["type"] == "dict"
+ assert result.metadata["size"] == 2
+
+ def test_load_complex_nested_json(self):
+ data = {
+ "users": [
+ {"id": 1, "profile": {"name": "Alice", "settings": {"theme": "dark"}}},
+ {"id": 2, "profile": {"name": "Bob", "settings": {"theme": "light"}}}
+ ],
+ "meta": {"total": 2, "version": "1.0"}
+ }
+ path = self._create_temp_json_file(data)
+ try:
+ result = self._load_from_path(path)
+ for value in ["Alice", "Bob", "dark", "light"]:
+ assert value in result.content
+ assert result.metadata["size"] == 2 # top-level keys
+ finally:
+ os.unlink(path)
+
+ def test_consistent_doc_id(self):
+ path = self._create_temp_json_file({"test": "data"})
+ try:
+ result1 = self._load_from_path(path)
+ result2 = self._load_from_path(path)
+ assert result1.doc_id == result2.doc_id
+ finally:
+ os.unlink(path)
+
+ # ------------------------------
+ # URL-based tests
+ # ------------------------------
+
+ @patch('requests.get')
+ def test_url_response_valid_json(self, mock_get):
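+        # The mock covers the requests.Response members the loader uses: .text, .json(), .raise_for_status().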
+ mock_get.return_value = Mock(
+ text='{"key": "value", "number": 123}',
+ json=Mock(return_value={"key": "value", "number": 123}),
+ raise_for_status=Mock()
+ )
+
+ loader = JSONLoader()
+ result = loader.load(SourceContent("https://api.example.com/data.json"))
+
+ assert all(val in result.content for val in ["key", "value", "number", "123"])
+ headers = mock_get.call_args[1]['headers']
+ assert "application/json" in headers['Accept']
+ assert "crewai-tools JSONLoader" in headers['User-Agent']
+
+ @patch('requests.get')
+ def test_url_response_not_json(self, mock_get):
+ mock_get.return_value = Mock(
+ text='{"key": "value"}',
+ json=Mock(side_effect=ValueError("Not JSON")),
+ raise_for_status=Mock()
+ )
+
+ loader = JSONLoader()
+ result = loader.load(SourceContent("https://example.com/data.json"))
+ assert all(part in result.content for part in ["key", "value"])
+
+ @patch('requests.get')
+ def test_url_with_custom_headers(self, mock_get):
+ mock_get.return_value = Mock(
+ text='{"data": "test"}',
+ json=Mock(return_value={"data": "test"}),
+ raise_for_status=Mock()
+ )
+ headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
+
+ loader = JSONLoader()
+ loader.load(SourceContent("https://api.example.com/data.json"), headers=headers)
+
+ assert mock_get.call_args[1]['headers'] == headers
+
+ @patch('requests.get')
+ def test_url_network_failure(self, mock_get):
+ mock_get.side_effect = Exception("Network error")
+ loader = JSONLoader()
+ with pytest.raises(ValueError, match="Error fetching JSON from URL"):
+ loader.load(SourceContent("https://api.example.com/data.json"))
+
+ @patch('requests.get')
+ def test_url_http_error(self, mock_get):
+ mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404")))
+ loader = JSONLoader()
+ with pytest.raises(ValueError, match="Error fetching JSON from URL"):
+ loader.load(SourceContent("https://api.example.com/404.json"))
diff --git a/tests/rag/test_mdx_loader.py b/tests/rag/test_mdx_loader.py
new file mode 100644
index 000000000..ef7944c28
--- /dev/null
+++ b/tests/rag/test_mdx_loader.py
@@ -0,0 +1,176 @@
+import os
+import tempfile
+import pytest
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.mdx_loader import MDXLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TestMDXLoader:
+
+ def _write_temp_mdx(self, content):
+ f = tempfile.NamedTemporaryFile(mode='w', suffix='.mdx', delete=False)
+ f.write(content)
+ f.close()
+ return f.name
+
+ def _load_from_file(self, content):
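+        # Writes content to a temp .mdx file, loads it, and always cleans up; returns (result, path).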
+ path = self._write_temp_mdx(content)
+ try:
+ loader = MDXLoader()
+ return loader.load(SourceContent(path)), path
+ finally:
+ os.unlink(path)
+
+ def test_load_basic_mdx_file(self):
+ content = """
+import Component from './Component'
+export const meta = { title: 'Test' }
+
+# Test MDX File
+
+This is a **markdown** file with JSX.
+
+<Component>
+  Nested content
+</Component>
+"""
+        result, path = self._load_from_file(content)
+
+        assert isinstance(result, LoaderResult)
+        assert "Test MDX File" in result.content
+        assert "Nested content" in result.content
+        assert result.source == path
+
+    def test_load_jsx_only_content(self):
+        result, _ = self._load_from_file("\n".join([
+            "<div>Only JSX content</div>",
+            "<div>No markdown here</div>",
+        ]))
+ assert "Only JSX content" in result.content
+ assert "No markdown here" in result.content
+
+ @patch('requests.get')
+ def test_load_mdx_from_url(self, mock_get):
+        mock_get.return_value = Mock(text="# MDX from URL\n\nContent here.", raise_for_status=Mock())
+
+        loader = MDXLoader()
+        result = loader.load(SourceContent("https://example.com/test.mdx"))
+
+        assert "MDX from URL" in result.content
diff --git a/tests/rag/test_webpage_loader.py b/tests/rag/test_webpage_loader.py
new file mode 100644
--- /dev/null
+++ b/tests/rag/test_webpage_loader.py
@@ -0,0 +1,79 @@
+from unittest.mock import patch, Mock
+
+from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
+from crewai_tools.rag.base_loader import LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class TestWebPageLoader:
+    def setup_mock_response(self, html):
+        # Mirrors only the requests.Response members the loader reads.
+        return Mock(text=html, raise_for_status=Mock())
+
+    def setup_mock_soup(self, text, title=None, script_style_elements=None):
+        # Mirrors the BeautifulSoup surface the loader uses: get_text(),
+        # .title, and soup(["script", "style"]) for elements to decompose.
+        soup = Mock()
+        soup.get_text.return_value = text
+        soup.title = Mock(string=title) if title else None
+        soup.return_value = script_style_elements or []
+        return soup
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_load_basic_webpage(self, mock_bs, mock_get):
+        mock_get.return_value = self.setup_mock_response("<html><head><title>Test Page</title></head><body><p>Test content</p></body></html>")
+        mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com"))
+
+        assert isinstance(result, LoaderResult)
+        assert result.content == "Test content"
+        assert result.metadata["title"] == "Test Page"
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
+        html = """
+        <html>
+        <head><title>Page with Scripts</title><style>body { color: red; }</style></head>
+        <body><script>alert('hi');</script><p>Visible content</p></body>
+        </html>
+        """
+        mock_get.return_value = self.setup_mock_response(html)
+        scripts = [Mock(), Mock()]
+        styles = [Mock()]
+        for el in scripts + styles:
+            el.decompose = Mock()
+        mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com/with-scripts"))
+
+        assert "Visible content" in result.content
+        for el in scripts + styles:
+            el.decompose.assert_called_once()
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
+        mock_get.return_value = self.setup_mock_response("<p>Messy text</p>")
+        mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n  More\t text  \n\n", title=None)
+
+        loader = WebPageLoader()
+        result = loader.load(SourceContent("https://example.com/messy-text"))
+        assert result.content is not None
+        assert result.metadata["title"] == ""
+
+    @patch('requests.get')
+    @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
+    def test_empty_or_missing_title(self, mock_bs, mock_get):
+        for title in [None, ""]:
+            mock_get.return_value = self.setup_mock_response("<p>Test content</p>")
+            mock_bs.return_value = self.setup_mock_soup("Test content", title=title)
+
+            loader = WebPageLoader()
+            result = loader.load(SourceContent("https://example.com"))
+
+            assert result.metadata["title"] == ""
") + mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page") + + loader = WebPageLoader() + result = loader.load(SourceContent("https://example.com")) + + assert isinstance(result, LoaderResult) + assert result.content == "Test content" + assert result.metadata["title"] == "Test Page" + + @patch('requests.get') + @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') + def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get): + html = """ +Visible content
+ """ + mock_get.return_value = self.setup_mock_response(html) + scripts = [Mock(), Mock()] + styles = [Mock()] + for el in scripts + styles: + el.decompose = Mock() + mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles) + + loader = WebPageLoader() + result = loader.load(SourceContent("https://example.com/with-scripts")) + + assert "Visible content" in result.content + for el in scripts + styles: + el.decompose.assert_called_once() + + @patch('requests.get') + @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') + def test_text_cleaning_and_title_handling(self, mock_bs, mock_get): + mock_get.return_value = self.setup_mock_response("Messy text
") + mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None) + + loader = WebPageLoader() + result = loader.load(SourceContent("https://example.com/messy-text")) + assert result.content is not None + assert result.metadata["title"] == "" + + @patch('requests.get') + @patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') + def test_empty_or_missing_title(self, mock_bs, mock_get): + for title in [None, ""]: + mock_get.return_value = self.setup_mock_response("