Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-02 07:42:40 +00:00
feat: merge latest changes from crewAI-tools main into packages/tools
- Merged upstream changes from crewAI-tools main branch
- Resolved conflicts due to monorepo structure (crewai_tools -> src/crewai_tools)
- Removed deprecated embedchain adapters
- Added new RAG loaders and crewai_rag_adapter
- Consolidated dependencies in pyproject.toml

Fixed critical linting issues:
- Added ClassVar annotations for mutable class attributes
- Added timeouts to requests calls (30s default)
- Fixed exception handling with proper 'from' clauses
- Added noqa comments for public API functions (backward compatibility)
- Updated ruff config to ignore expected patterns:
  - F401 in __init__ files (intentional re-exports)
  - S101 in test files (assertions are expected)
  - S607 for subprocess calls (uv/pip commands are safe)

Remaining issues are from upstream code and will be addressed in separate PRs.
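For illustration, a minimal sketch of the lint patterns this commit applies (the names here are hypothetical, not taken from the diff):

    from typing import Any, ClassVar

    import requests


    class ExampleTool:
        # Mutable class attribute annotated as ClassVar so linters treat it
        # as shared class state rather than a per-instance field.
        default_headers: ClassVar[dict[str, str]] = {"Accept": "application/json"}

        def fetch(self, url: str) -> Any:
            try:
                # Explicit 30-second timeout so a stalled server cannot hang the call.
                response = requests.get(url, headers=self.default_headers, timeout=30)
                response.raise_for_status()
                return response.json()
            except requests.RequestException as e:
                # Chain with "from" so the original traceback is preserved.
                raise ValueError(f"Error fetching {url}: {e}") from e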
packages/tools/src/crewai_tools/rag/__init__.py
@@ -3,6 +3,6 @@ from crewai_tools.rag.data_types import DataType

__all__ = [
    "RAG",
    "EmbeddingService",
    "DataType",
    "EmbeddingService",
]

packages/tools/src/crewai_tools/rag/base_loader.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any

from pydantic import BaseModel, Field

from crewai_tools.rag.misc import compute_sha256
@@ -9,19 +10,22 @@ from crewai_tools.rag.source_content import SourceContent
class LoaderResult(BaseModel):
    content: str = Field(description="The text content of the source")
    source: str = Field(description="The source of the content", default="unknown")
    metadata: Dict[str, Any] = Field(description="The metadata of the source", default_factory=dict)
    metadata: dict[str, Any] = Field(
        description="The metadata of the source", default_factory=dict
    )
    doc_id: str = Field(description="The id of the document")


class BaseLoader(ABC):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
    def __init__(self, config: dict[str, Any] | None = None):
        self.config = config or {}

    @abstractmethod
    def load(self, content: SourceContent, **kwargs) -> LoaderResult:
        ...
    def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...

    def generate_doc_id(self, source_ref: str | None = None, content: str | None = None) -> str:
    def generate_doc_id(
        self, source_ref: str | None = None, content: str | None = None
    ) -> str:
        """
        Generate a unique document id based on the source reference and content.
        If the source reference is not provided, the content is used as the source reference.

packages/tools/src/crewai_tools/rag/chunkers/__init__.py
@@ -1,15 +1,19 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.chunkers.default_chunker import DefaultChunker
from crewai_tools.rag.chunkers.text_chunker import TextChunker, DocxChunker, MdxChunker
from crewai_tools.rag.chunkers.structured_chunker import CsvChunker, JsonChunker, XmlChunker
from crewai_tools.rag.chunkers.structured_chunker import (
    CsvChunker,
    JsonChunker,
    XmlChunker,
)
from crewai_tools.rag.chunkers.text_chunker import DocxChunker, MdxChunker, TextChunker

__all__ = [
    "BaseChunker",
    "DefaultChunker",
    "TextChunker",
    "DocxChunker",
    "MdxChunker",
    "CsvChunker",
    "DefaultChunker",
    "DocxChunker",
    "JsonChunker",
    "MdxChunker",
    "TextChunker",
    "XmlChunker",
]

packages/tools/src/crewai_tools/rag/chunkers/base_chunker.py
@@ -1,6 +1,6 @@
from typing import List, Optional
import re


class RecursiveCharacterTextSplitter:
    """
    A text splitter that recursively splits text based on a hierarchy of separators.
@@ -10,7 +10,7 @@ class RecursiveCharacterTextSplitter:
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        """
@@ -23,7 +23,9 @@ class RecursiveCharacterTextSplitter:
            keep_separator: Whether to keep the separator in the split text
        """
        if chunk_overlap >= chunk_size:
            raise ValueError(f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})")
            raise ValueError(
                f"Chunk overlap ({chunk_overlap}) cannot be >= chunk size ({chunk_size})"
            )

        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
@@ -36,10 +38,10 @@ class RecursiveCharacterTextSplitter:
            "",
        ]

    def split_text(self, text: str) -> List[str]:
    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        separator = separators[-1]
        new_separators = []

@@ -49,7 +51,7 @@ class RecursiveCharacterTextSplitter:
                break
            if re.search(re.escape(sep), text):
                separator = sep
                new_separators = separators[i + 1:]
                new_separators = separators[i + 1 :]
                break

        splits = self._split_text_with_separator(text, separator)
@@ -68,7 +70,7 @@ class RecursiveCharacterTextSplitter:

        return self._merge_splits(good_splits, separator)

    def _split_text_with_separator(self, text: str, separator: str) -> List[str]:
    def _split_text_with_separator(self, text: str, separator: str) -> list[str]:
        if separator == "":
            return list(text)

@@ -90,16 +92,15 @@ class RecursiveCharacterTextSplitter:
            splits[-1] += separator

            return [s for s in splits if s]
        else:
            return text.split(separator)
        return text.split(separator)

    def _split_by_characters(self, text: str) -> List[str]:
    def _split_by_characters(self, text: str) -> list[str]:
        chunks = []
        for i in range(0, len(text), self._chunk_size):
            chunks.append(text[i:i + self._chunk_size])
            chunks.append(text[i : i + self._chunk_size])
        return chunks

    def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
    def _merge_splits(self, splits: list[str], separator: str) -> list[str]:
        """Merge splits into chunks with proper overlap."""
        docs = []
        current_doc = []
@@ -112,7 +113,10 @@ class RecursiveCharacterTextSplitter:
                if separator == "":
                    doc = "".join(current_doc)
                else:
                    doc = separator.join(current_doc)
                if self._keep_separator and separator == " ":
                    doc = "".join(current_doc)
                else:
                    doc = separator.join(current_doc)

                if doc:
                    docs.append(doc)
@@ -133,15 +137,25 @@ class RecursiveCharacterTextSplitter:
        if separator == "":
            doc = "".join(current_doc)
        else:
            doc = separator.join(current_doc)
        if self._keep_separator and separator == " ":
            doc = "".join(current_doc)
        else:
            doc = separator.join(current_doc)

        if doc:
            docs.append(doc)

        return docs


class BaseChunker:
    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        """
        Initialize the Chunker

@@ -159,8 +173,7 @@ class BaseChunker:
            keep_separator=keep_separator,
        )

    def chunk(self, text: str) -> List[str]:
    def chunk(self, text: str) -> list[str]:
        if not text or not text.strip():
            return []

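As a quick usage sketch (assuming only the signatures visible in the hunks above; not part of the diff):

    from crewai_tools.rag.chunkers.base_chunker import BaseChunker

    # Defaults shown in the diff: chunk_size=1000, chunk_overlap=200,
    # keep_separator=True; chunk() returns [] for empty/whitespace input.
    chunker = BaseChunker(chunk_size=1000, chunk_overlap=200)
    chunks = chunker.chunk("A long document...\n\nParagraph and sentence boundaries are tried first.")
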
packages/tools/src/crewai_tools/rag/chunkers/default_chunker.py
@@ -1,6 +1,12 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class DefaultChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 20, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 2000,
        chunk_overlap: int = 20,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)

packages/tools/src/crewai_tools/rag/chunkers/structured_chunker.py
@@ -1,49 +1,66 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class CsvChunker(BaseChunker):
    def __init__(self, chunk_size: int = 1200, chunk_overlap: int = 100, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 1200,
        chunk_overlap: int = 100,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\nRow ",  # Row boundaries (from CSVLoader format)
                "\n",  # Line breaks
                " | ",  # Column separators
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
                "\nRow ",  # Row boundaries (from CSVLoader format)
                "\n",  # Line breaks
                " | ",  # Column separators
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class JsonChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2000, chunk_overlap: int = 200, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 2000,
        chunk_overlap: int = 200,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n",  # Object/array boundaries
                "\n",  # Line breaks
                "},",  # Object endings
                "],",  # Array endings
                ", ",  # Property separators
                ": ",  # Key-value separators
                " ",  # Word breaks
                "",  # Character level
                "\n\n",  # Object/array boundaries
                "\n",  # Line breaks
                "},",  # Object endings
                "],",  # Array endings
                ", ",  # Property separators
                ": ",  # Key-value separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class XmlChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n",  # Element boundaries
                "\n",  # Line breaks
                ">",  # Tag endings
                ". ",  # Sentence endings (for text content)
                "! ",  # Exclamation endings
                "? ",  # Question endings
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
                "\n\n",  # Element boundaries
                "\n",  # Line breaks
                ">",  # Tag endings
                ". ",  # Sentence endings (for text content)
                "! ",  # Exclamation endings
                "? ",  # Question endings
                ", ",  # Comma separators
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)

packages/tools/src/crewai_tools/rag/chunkers/text_chunker.py
@@ -1,59 +1,76 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class TextChunker(BaseChunker):
    def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 150, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 1500,
        chunk_overlap: int = 150,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (sections)
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class DocxChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Multiple line breaks (major sections)
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)


class MdxChunker(BaseChunker):
    def __init__(self, chunk_size: int = 3000, chunk_overlap: int = 300, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 3000,
        chunk_overlap: int = 300,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n## ",  # H2 headers (major sections)
                "\n## ",  # H2 headers (major sections)
                "\n### ",  # H3 headers (subsections)
                "\n#### ",  # H4 headers (sub-subsections)
                "\n\n",  # Paragraph breaks
                "\n```",  # Code block boundaries
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
                "\n#### ",  # H4 headers (sub-subsections)
                "\n\n",  # Paragraph breaks
                "\n```",  # Code block boundaries
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)

packages/tools/src/crewai_tools/rag/chunkers/web_chunker.py
@@ -1,20 +1,25 @@
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from typing import List, Optional


class WebsiteChunker(BaseChunker):
    def __init__(self, chunk_size: int = 2500, chunk_overlap: int = 250, separators: Optional[List[str]] = None, keep_separator: bool = True):
    def __init__(
        self,
        chunk_size: int = 2500,
        chunk_overlap: int = 250,
        separators: list[str] | None = None,
        keep_separator: bool = True,
    ):
        if separators is None:
            separators = [
                "\n\n\n",  # Major section breaks
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
                "\n\n",  # Paragraph breaks
                "\n",  # Line breaks
                ". ",  # Sentence endings
                "! ",  # Exclamation endings
                "? ",  # Question endings
                "; ",  # Semicolon breaks
                ", ",  # Comma breaks
                " ",  # Word breaks
                "",  # Character level
            ]
        super().__init__(chunk_size, chunk_overlap, separators, keep_separator)

packages/tools/src/crewai_tools/rag/crewai_rag_adapter.py
@@ -1,18 +1,18 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
from uuid import uuid4

import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr

from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.tools.rag.rag_tool import Adapter

logger = logging.getLogger(__name__)

@@ -22,29 +22,21 @@ class EmbeddingService:
        self.model = model
        self.kwargs = kwargs

    def embed_text(self, text: str) -> List[float]:
    def embed_text(self, text: str) -> list[float]:
        try:
            response = litellm.embedding(
                model=self.model,
                input=[text],
                **self.kwargs
            )
            return response.data[0]['embedding']
            response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
            return response.data[0]["embedding"]
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            raise

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            return []

        try:
            response = litellm.embedding(
                model=self.model,
                input=texts,
                **self.kwargs
            )
            return [data['embedding'] for data in response.data]
            response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
            return [data["embedding"] for data in response.data]
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            raise
@@ -53,18 +45,18 @@ class EmbeddingService:

class Document(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    metadata: dict[str, Any] = Field(default_factory=dict)
    data_type: DataType = DataType.TEXT
    source: Optional[str] = None
    source: str | None = None


class RAG(Adapter):
    collection_name: str = "crewai_knowledge_base"
    persist_directory: Optional[str] = None
    persist_directory: str | None = None
    embedding_model: str = "text-embedding-3-large"
    summarize: bool = False
    top_k: int = 5
    embedding_config: Dict[str, Any] = Field(default_factory=dict)
    embedding_config: dict[str, Any] = Field(default_factory=dict)

    _client: Any = PrivateAttr()
    _collection: Any = PrivateAttr()
@@ -79,10 +71,15 @@ class RAG(Adapter):

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"}
                metadata={
                    "hnsw:space": "cosine",
                    "description": "CrewAI Knowledge Base",
                },
            )

            self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
            self._embedding_service = EmbeddingService(
                model=self.embedding_model, **self.embedding_config
            )
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise
@@ -92,11 +89,11 @@ class RAG(Adapter):
    def add(
        self,
        content: str | Path,
        data_type: Optional[Union[str, DataType]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Any
        data_type: str | DataType | None = None,
        metadata: dict[str, Any] | None = None,
        loader: BaseLoader | None = None,
        chunker: BaseChunker | None = None,
        **kwargs: Any,
    ) -> None:
        source_content = SourceContent(content)

@@ -111,11 +108,19 @@ class RAG(Adapter):
        loader_result = loader.load(source_content)
        doc_id = loader_result.doc_id

        existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
        existing_doc_id = existing_doc and existing_doc['metadatas'][0]['doc_id'] if existing_doc['metadatas'] else None
        existing_doc = self._collection.get(
            where={"source": source_content.source_ref}, limit=1
        )
        existing_doc_id = (
            existing_doc and existing_doc["metadatas"][0]["doc_id"]
            if existing_doc["metadatas"]
            else None
        )

        if existing_doc_id == doc_id:
            logger.warning(f"Document with source {loader_result.source} already exists")
            logger.warning(
                f"Document with source {loader_result.source} already exists"
            )
            return

        # Document with same source ref does exists but the content has changed, deleting the oldest reference
@@ -128,14 +133,16 @@ class RAG(Adapter):
        chunks = chunker.chunk(loader_result.content)
        for i, chunk in enumerate(chunks):
            doc_metadata = (metadata or {}).copy()
            doc_metadata['chunk_index'] = i
            documents.append(Document(
                id=compute_sha256(chunk),
                content=chunk,
                metadata=doc_metadata,
                data_type=data_type,
                source=loader_result.source
            ))
            doc_metadata["chunk_index"] = i
            documents.append(
                Document(
                    id=compute_sha256(chunk),
                    content=chunk,
                    metadata=doc_metadata,
                    data_type=data_type,
                    source=loader_result.source,
                )
            )

        if not documents:
            logger.warning("No documents to add")
@@ -153,11 +160,13 @@ class RAG(Adapter):

        for doc in documents:
            doc_metadata = doc.metadata.copy()
            doc_metadata.update({
                "data_type": doc.data_type.value,
                "source": doc.source,
                "doc_id": doc_id
            })
            doc_metadata.update(
                {
                    "data_type": doc.data_type.value,
                    "source": doc.source,
                    "doc_id": doc_id,
                }
            )
            metadatas.append(doc_metadata)

        try:
@@ -171,7 +180,7 @@ class RAG(Adapter):
        except Exception as e:
            logger.error(f"Failed to add documents to ChromaDB: {e}")

    def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
    def query(self, question: str, where: dict[str, Any] | None = None) -> str:
        try:
            question_embedding = self._embedding_service.embed_text(question)

@@ -179,10 +188,14 @@ class RAG(Adapter):
                query_embeddings=[question_embedding],
                n_results=self.top_k,
                where=where,
                include=["documents", "metadatas", "distances"]
                include=["documents", "metadatas", "distances"],
            )

            if not results or not results.get("documents") or not results["documents"][0]:
            if (
                not results
                or not results.get("documents")
                or not results["documents"][0]
            ):
                return "No relevant content found."

            documents = results["documents"][0]
@@ -195,8 +208,12 @@ class RAG(Adapter):
                metadata = metadatas[i] if i < len(metadatas) else {}
                distance = distances[i] if i < len(distances) else 1.0
                source = metadata.get("source", "unknown") if metadata else "unknown"
                score = 1 - distance if distance is not None else 0  # Convert distance to similarity
                formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")
                score = (
                    1 - distance if distance is not None else 0
                )  # Convert distance to similarity
                formatted_results.append(
                    f"[Source: {source}, Relevance: {score:.3f}]\n{doc}"
                )

            return "\n\n".join(formatted_results)
        except Exception as e:
@@ -210,23 +227,25 @@ class RAG(Adapter):
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")

    def get_collection_info(self) -> Dict[str, Any]:
    def get_collection_info(self) -> dict[str, Any]:
        try:
            count = self._collection.count()
            return {
                "name": self.collection_name,
                "count": count,
                "embedding_model": self.embedding_model
                "embedding_model": self.embedding_model,
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {"error": str(e)}

    def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
    def _get_data_type(
        self, content: SourceContent, data_type: str | DataType | None = None
    ) -> DataType:
        try:
            if isinstance(data_type, str):
                return DataType(data_type)
        except Exception as e:
        except Exception:
            pass

        return content.data_type

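Taken together, the adapter exposes a small add/query surface. A hypothetical end-to-end call, based only on the fields and signatures above (the module path is assumed, not shown in the diff):

    from pathlib import Path

    from crewai_tools.rag.crewai_rag_adapter import RAG  # assumed module path

    rag = RAG(collection_name="crewai_knowledge_base", embedding_model="text-embedding-3-large")
    rag.add(Path("docs/guide.pdf"), metadata={"team": "docs"})  # loads, chunks, embeds, upserts
    print(rag.query("How do I configure the loader?", where={"team": "docs"}))
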
packages/tools/src/crewai_tools/rag/data_types.py
@@ -1,9 +1,11 @@
import os
from enum import Enum
from pathlib import Path
from urllib.parse import urlparse
import os
from crewai_tools.rag.chunkers.base_chunker import BaseChunker

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DataType(str, Enum):
    PDF_FILE = "pdf_file"
@@ -25,29 +27,38 @@ class DataType(str, Enum):
    # Web types
    WEBSITE = "website"
    DOCS_SITE = "docs_site"
    YOUTUBE_VIDEO = "youtube_video"
    YOUTUBE_CHANNEL = "youtube_channel"

    # Raw types
    TEXT = "text"


    def get_chunker(self) -> BaseChunker:
        from importlib import import_module

        chunkers = {
            DataType.PDF_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
            DataType.TEXT: ("text_chunker", "TextChunker"),
            DataType.DOCX: ("text_chunker", "DocxChunker"),
            DataType.MDX: ("text_chunker", "MdxChunker"),

            # Structured formats
            DataType.CSV: ("structured_chunker", "CsvChunker"),
            DataType.JSON: ("structured_chunker", "JsonChunker"),
            DataType.XML: ("structured_chunker", "XmlChunker"),

            DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
            DataType.DIRECTORY: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
            DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
            DataType.GITHUB: ("text_chunker", "TextChunker"),
            DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
            DataType.MYSQL: ("text_chunker", "TextChunker"),
            DataType.POSTGRES: ("text_chunker", "TextChunker"),
        }

        module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
        if self not in chunkers:
            raise ValueError(f"No chunker defined for {self}")
        module_name, class_name = chunkers[self]
        module_path = f"crewai_tools.rag.chunkers.{module_name}"

        try:
@@ -60,6 +71,7 @@ class DataType(str, Enum):
        from importlib import import_module

        loaders = {
            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
            DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
            DataType.TEXT: ("text_loader", "TextLoader"),
            DataType.XML: ("xml_loader", "XMLLoader"),
@@ -69,9 +81,20 @@ class DataType(str, Enum):
            DataType.DOCX: ("docx_loader", "DOCXLoader"),
            DataType.CSV: ("csv_loader", "CSVLoader"),
            DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
            DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
            DataType.YOUTUBE_CHANNEL: (
                "youtube_channel_loader",
                "YoutubeChannelLoader",
            ),
            DataType.GITHUB: ("github_loader", "GithubLoader"),
            DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
            DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
            DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
        }

        module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
        if self not in loaders:
            raise ValueError(f"No loader defined for {self}")
        module_name, class_name = loaders[self]
        module_path = f"crewai_tools.rag.loaders.{module_name}"
        try:
            module = import_module(module_path)
@@ -79,6 +102,7 @@ class DataType(str, Enum):
        except Exception as e:
            raise ValueError(f"Error loading loader for {self}: {e}")


class DataTypes:
    @staticmethod
    def from_content(content: str | Path | None = None) -> DataType:

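Note the behavioral change in both dispatch tables: an unmapped data type now raises ValueError instead of silently falling back to DefaultChunker/TextLoader. A small sketch of the resulting dispatch (illustrative, not from the diff):

    from crewai_tools.rag.data_types import DataType

    chunker = DataType.CSV.get_chunker()     # resolves to CsvChunker via the table
    loader = DataType.PDF_FILE.get_loader()  # resolves to PDFLoader
    # A DataType value missing from the tables now raises, e.g.:
    #   ValueError: No chunker defined for <DataType...>
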
packages/tools/src/crewai_tools/rag/loaders/__init__.py
@@ -1,20 +1,26 @@
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.loaders.xml_loader import XMLLoader
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader

__all__ = [
    "CSVLoader",
    "DOCXLoader",
    "DirectoryLoader",
    "JSONLoader",
    "MDXLoader",
    "PDFLoader",
    "TextFileLoader",
    "TextLoader",
    "XMLLoader",
    "WebPageLoader",
    "MDXLoader",
    "JSONLoader",
    "DOCXLoader",
    "CSVLoader",
    "DirectoryLoader",
    "XMLLoader",
    "YoutubeChannelLoader",
    "YoutubeVideoLoader",
]

packages/tools/src/crewai_tools/rag/loaders/csv_loader.py
@@ -17,21 +17,23 @@ class CSVLoader(BaseLoader):

        return self._parse_csv(content_str, source_ref)


    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "text/csv, application/csv, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)"
        })
        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/csv, application/csv, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching CSV from URL {url}: {str(e)}")
            raise ValueError(f"Error fetching CSV from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -57,7 +59,7 @@ class CSVLoader(BaseLoader):
            metadata = {
                "format": "csv",
                "columns": headers,
                "rows": len(text_parts) - 2 if headers else 0
                "rows": len(text_parts) - 2 if headers else 0,
            }

        except Exception as e:
@@ -68,5 +70,5 @@ class CSVLoader(BaseLoader):
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )

packages/tools/src/crewai_tools/rag/loaders/directory_loader.py
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import List

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
@@ -22,7 +21,9 @@ class DirectoryLoader(BaseLoader):
        source_ref = source_content.source_ref

        if source_content.is_url():
            raise ValueError("URL directory loading is not supported. Please provide a local directory path.")
            raise ValueError(
                "URL directory loading is not supported. Please provide a local directory path."
            )

        if not os.path.exists(source_ref):
            raise FileNotFoundError(f"Directory does not exist: {source_ref}")
@@ -38,7 +39,9 @@ class DirectoryLoader(BaseLoader):
        exclude_extensions = kwargs.get("exclude_extensions", None)
        max_files = kwargs.get("max_files", None)

        files = self._find_files(dir_path, recursive, include_extensions, exclude_extensions)
        files = self._find_files(
            dir_path, recursive, include_extensions, exclude_extensions
        )

        if max_files and len(files) > max_files:
            files = files[:max_files]
@@ -52,13 +55,15 @@ class DirectoryLoader(BaseLoader):
                result = self._process_single_file(file_path)
                if result:
                    all_contents.append(f"=== File: {file_path} ===\n{result.content}")
                    processed_files.append({
                        "path": file_path,
                        "metadata": result.metadata,
                        "source": result.source
                    })
                    processed_files.append(
                        {
                            "path": file_path,
                            "metadata": result.metadata,
                            "source": result.source,
                        }
                    )
            except Exception as e:
                error_msg = f"Error processing {file_path}: {str(e)}"
                error_msg = f"Error processing {file_path}: {e!s}"
                errors.append(error_msg)
                all_contents.append(f"=== File: {file_path} (ERROR) ===\n{error_msg}")

@@ -71,23 +76,29 @@ class DirectoryLoader(BaseLoader):
            "processed_files": len(processed_files),
            "errors": len(errors),
            "file_details": processed_files,
            "error_details": errors
            "error_details": errors,
        }

        return LoaderResult(
            content=combined_content,
            source=dir_path,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content)
            doc_id=self.generate_doc_id(source_ref=dir_path, content=combined_content),
        )

    def _find_files(self, dir_path: str, recursive: bool, include_ext: List[str] | None = None, exclude_ext: List[str] | None = None) -> List[str]:
    def _find_files(
        self,
        dir_path: str,
        recursive: bool,
        include_ext: list[str] | None = None,
        exclude_ext: list[str] | None = None,
    ) -> list[str]:
        """Find all files in directory matching criteria."""
        files = []

        if recursive:
            for root, dirs, filenames in os.walk(dir_path):
                dirs[:] = [d for d in dirs if not d.startswith('.')]
                dirs[:] = [d for d in dirs if not d.startswith(".")]

                for filename in filenames:
                    if self._should_include_file(filename, include_ext, exclude_ext):
@@ -96,26 +107,37 @@ class DirectoryLoader(BaseLoader):
            try:
                for item in os.listdir(dir_path):
                    item_path = os.path.join(dir_path, item)
                    if os.path.isfile(item_path) and self._should_include_file(item, include_ext, exclude_ext):
                    if os.path.isfile(item_path) and self._should_include_file(
                        item, include_ext, exclude_ext
                    ):
                        files.append(item_path)
            except PermissionError:
                pass

        return sorted(files)

    def _should_include_file(self, filename: str, include_ext: List[str] = None, exclude_ext: List[str] = None) -> bool:
    def _should_include_file(
        self,
        filename: str,
        include_ext: list[str] | None = None,
        exclude_ext: list[str] | None = None,
    ) -> bool:
        """Determine if a file should be included based on criteria."""
        if filename.startswith('.'):
        if filename.startswith("."):
            return False

        _, ext = os.path.splitext(filename.lower())

        if include_ext:
            if ext not in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in include_ext]:
            if ext not in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in include_ext
            ]:
                return False

        if exclude_ext:
            if ext in [e.lower() if e.startswith('.') else f'.{e.lower()}' for e in exclude_ext]:
            if ext in [
                e.lower() if e.startswith(".") else f".{e.lower()}" for e in exclude_ext
            ]:
                return False

        return True
@@ -132,11 +154,13 @@ class DirectoryLoader(BaseLoader):
        if result.metadata is None:
            result.metadata = {}

        result.metadata.update({
            "file_path": file_path,
            "file_size": os.path.getsize(file_path),
            "data_type": str(data_type),
            "loader_type": loader.__class__.__name__
        })
        result.metadata.update(
            {
                "file_path": file_path,
                "file_size": os.path.getsize(file_path),
                "data_type": str(data_type),
                "loader_type": loader.__class__.__name__,
            }
        )

        return result

packages/tools/src/crewai_tools/rag/loaders/docs_site_loader.py (new file, 106 lines)
@@ -0,0 +1,106 @@
"""Documentation site loader."""

from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DocsSiteLoader(BaseLoader):
    """Loader for documentation websites."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a documentation site.

        Args:
            source: Documentation site URL
            **kwargs: Additional arguments

        Returns:
            LoaderResult with documentation content
        """
        docs_url = source.source

        try:
            response = requests.get(docs_url, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")

        soup = BeautifulSoup(response.text, "html.parser")

        for script in soup(["script", "style"]):
            script.decompose()

        title = soup.find("title")
        title_text = title.get_text(strip=True) if title else "Documentation"

        main_content = None
        for selector in [
            "main",
            "article",
            '[role="main"]',
            ".content",
            "#content",
            ".documentation",
        ]:
            main_content = soup.select_one(selector)
            if main_content:
                break

        if not main_content:
            main_content = soup.find("body")

        if not main_content:
            raise ValueError(
                f"Unable to extract content from documentation site: {docs_url}"
            )

        text_parts = [f"Title: {title_text}", ""]

        headings = main_content.find_all(["h1", "h2", "h3"])
        if headings:
            text_parts.append("Table of Contents:")
            for heading in headings[:15]:
                level = int(heading.name[1])
                indent = " " * (level - 1)
                text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
            text_parts.append("")

        text = main_content.get_text(separator="\n", strip=True)
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        text_parts.extend(lines)

        nav_links = []
        for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
            nav = soup.select_one(nav_selector)
            if nav:
                links = nav.find_all("a", href=True)
                for link in links[:20]:
                    href = link["href"]
                    if not href.startswith(("http://", "https://", "mailto:", "#")):
                        full_url = urljoin(docs_url, href)
                        nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")

        if nav_links:
            text_parts.append("")
            text_parts.append("Related documentation pages:")
            text_parts.extend(nav_links[:10])

        content = "\n".join(text_parts)

        if len(content) > 100000:
            content = content[:100000] + "\n\n[Content truncated...]"

        return LoaderResult(
            content=content,
            metadata={
                "source": docs_url,
                "title": title_text,
                "domain": urlparse(docs_url).netloc,
            },
            doc_id=self.generate_doc_id(source_ref=docs_url, content=content),
        )

packages/tools/src/crewai_tools/rag/loaders/docx_loader.py
@@ -10,7 +10,9 @@ class DOCXLoader(BaseLoader):
        try:
            from docx import Document as DocxDocument
        except ImportError:
            raise ImportError("python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]")
            raise ImportError(
                "python-docx is required for DOCX loading. Install with: 'uv pip install python-docx' or pip install crewai-tools[rag]"
            )

        source_ref = source_content.source_ref

@@ -23,28 +25,35 @@ class DOCXLoader(BaseLoader):
        elif source_content.path_exists():
            return self._load_from_file(source_ref, source_ref, DocxDocument)
        else:
            raise ValueError(f"Source must be a valid file path or URL, got: {source_content.source}")
            raise ValueError(
                f"Source must be a valid file path or URL, got: {source_content.source}"
            )

    def _download_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)"
        })
        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools DOCXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            # Create temporary file to save the DOCX content
            with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as temp_file:
            with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as temp_file:
                temp_file.write(response.content)
                return temp_file.name
        except Exception as e:
            raise ValueError(f"Error fetching DOCX from URL {url}: {str(e)}")
            raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}")

    def _load_from_file(self, file_path: str, source_ref: str, DocxDocument) -> LoaderResult:
    def _load_from_file(
        self, file_path: str, source_ref: str, DocxDocument
    ) -> LoaderResult:
        try:
            doc = DocxDocument(file_path)

@@ -58,15 +67,15 @@ class DOCXLoader(BaseLoader):
            metadata = {
                "format": "docx",
                "paragraphs": len(doc.paragraphs),
                "tables": len(doc.tables)
                "tables": len(doc.tables),
            }

            return LoaderResult(
                content=content,
                source=source_ref,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=source_ref, content=content)
                doc_id=self.generate_doc_id(source_ref=source_ref, content=content),
            )

        except Exception as e:
            raise ValueError(f"Error loading DOCX file: {str(e)}")
            raise ValueError(f"Error loading DOCX file: {e!s}")

packages/tools/src/crewai_tools/rag/loaders/github_loader.py (new file, 112 lines)
@@ -0,0 +1,112 @@
"""GitHub repository content loader."""

from github import Github, GithubException

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class GithubLoader(BaseLoader):
    """Loader for GitHub repository content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a GitHub repository.

        Args:
            source: GitHub repository URL
            **kwargs: Additional arguments including gh_token and content_types

        Returns:
            LoaderResult with repository content
        """
        metadata = kwargs.get("metadata", {})
        gh_token = metadata.get("gh_token")
        content_types = metadata.get("content_types", ["code", "repo"])

        repo_url = source.source
        if not repo_url.startswith("https://github.com/"):
            raise ValueError(f"Invalid GitHub URL: {repo_url}")

        parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
        if len(parts) < 2:
            raise ValueError(f"Invalid GitHub repository URL: {repo_url}")

        repo_name = f"{parts[0]}/{parts[1]}"

        g = Github(gh_token) if gh_token else Github()

        try:
            repo = g.get_repo(repo_name)
        except GithubException as e:
            raise ValueError(f"Unable to access repository {repo_name}: {e}")

        all_content = []

        if "repo" in content_types:
            all_content.append(f"Repository: {repo.full_name}")
            all_content.append(f"Description: {repo.description or 'No description'}")
            all_content.append(f"Language: {repo.language or 'Not specified'}")
            all_content.append(f"Stars: {repo.stargazers_count}")
            all_content.append(f"Forks: {repo.forks_count}")
            all_content.append("")

        if "code" in content_types:
            try:
                readme = repo.get_readme()
                all_content.append("README:")
                all_content.append(
                    readme.decoded_content.decode("utf-8", errors="ignore")
                )
                all_content.append("")
            except GithubException:
                pass

            try:
                contents = repo.get_contents("")
                if isinstance(contents, list):
                    all_content.append("Repository structure:")
                    for content_file in contents[:20]:
                        all_content.append(
                            f"- {content_file.path} ({content_file.type})"
                        )
                    all_content.append("")
            except GithubException:
                pass

        if "pr" in content_types:
            prs = repo.get_pulls(state="open")
            pr_list = list(prs[:5])
            if pr_list:
                all_content.append("Recent Pull Requests:")
                for pr in pr_list:
                    all_content.append(f"- PR #{pr.number}: {pr.title}")
                    if pr.body:
                        body_preview = pr.body[:200].replace("\n", " ")
                        all_content.append(f" {body_preview}")
                all_content.append("")

        if "issue" in content_types:
            issues = repo.get_issues(state="open")
            issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
            if issue_list:
                all_content.append("Recent Issues:")
                for issue in issue_list:
                    all_content.append(f"- Issue #{issue.number}: {issue.title}")
                    if issue.body:
                        body_preview = issue.body[:200].replace("\n", " ")
                        all_content.append(f" {body_preview}")
                all_content.append("")

        if not all_content:
            raise ValueError(f"No content could be loaded from repository: {repo_url}")

        content = "\n".join(all_content)
        return LoaderResult(
            content=content,
            metadata={
                "source": repo_url,
                "repo": repo_name,
                "content_types": content_types,
            },
            doc_id=self.generate_doc_id(source_ref=repo_url, content=content),
        )

packages/tools/src/crewai_tools/rag/loaders/json_loader.py
@@ -1,7 +1,7 @@
import json

from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class JSONLoader(BaseLoader):
@@ -19,17 +19,24 @@ class JSONLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/json",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)"
        })
        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/json",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text if not self._is_json_response(response) else json.dumps(response.json(), indent=2)
            return (
                response.text
                if not self._is_json_response(response)
                else json.dumps(response.json(), indent=2)
            )
        except Exception as e:
            raise ValueError(f"Error fetching JSON from URL {url}: {str(e)}")
            raise ValueError(f"Error fetching JSON from URL {url}: {e!s}")

    def _is_json_response(self, response) -> bool:
        try:
@@ -46,7 +53,9 @@ class JSONLoader(BaseLoader):
        try:
            data = json.loads(content)
            if isinstance(data, dict):
                text = "\n".join(f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items())
                text = "\n".join(
                    f"{k}: {json.dumps(v, indent=0)}" for k, v in data.items()
                )
            elif isinstance(data, list):
                text = "\n".join(json.dumps(item, indent=0) for item in data)
            else:
@@ -55,7 +64,7 @@ class JSONLoader(BaseLoader):
            metadata = {
                "format": "json",
                "type": type(data).__name__,
                "size": len(data) if isinstance(data, (list, dict)) else 1
                "size": len(data) if isinstance(data, (list, dict)) else 1,
            }
        except json.JSONDecodeError as e:
            text = content
@@ -65,5 +74,5 @@ class JSONLoader(BaseLoader):
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )

packages/tools/src/crewai_tools/rag/loaders/mdx_loader.py
@@ -3,6 +3,7 @@ import re
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MDXLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
@@ -18,17 +19,20 @@ class MDXLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "text/markdown, text/x-markdown, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)"
        })
        headers = kwargs.get(
            "headers",
            {
                "Accept": "text/markdown, text/x-markdown, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching MDX from URL {url}: {str(e)}")
            raise ValueError(f"Error fetching MDX from URL {url}: {e!s}")

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -38,16 +42,20 @@ class MDXLoader(BaseLoader):
        cleaned_content = content

        # Remove import statements
        cleaned_content = re.sub(r'^import\s+.*?\n', '', cleaned_content, flags=re.MULTILINE)
        cleaned_content = re.sub(
            r"^import\s+.*?\n", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove export statements
        cleaned_content = re.sub(r'^export\s+.*?(?:\n|$)', '', cleaned_content, flags=re.MULTILINE)
        cleaned_content = re.sub(
            r"^export\s+.*?(?:\n|$)", "", cleaned_content, flags=re.MULTILINE
        )

        # Remove JSX tags (simple approach)
        cleaned_content = re.sub(r'<[^>]+>', '', cleaned_content)
        cleaned_content = re.sub(r"<[^>]+>", "", cleaned_content)

        # Clean up extra whitespace
        cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content)
        cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content)
        cleaned_content = cleaned_content.strip()

        metadata = {"format": "mdx"}
@@ -55,5 +63,5 @@ class MDXLoader(BaseLoader):
            content=cleaned_content,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content)
            doc_id=self.generate_doc_id(source_ref=source_ref, content=cleaned_content),
        )

packages/tools/src/crewai_tools/rag/loaders/mysql_loader.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
"""MySQL database loader."""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pymysql
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class MySQLLoader(BaseLoader):
|
||||
"""Loader for MySQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a MySQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for MySQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["mysql", "mysql+pymysql"]:
|
||||
raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 3306,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor,
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = pymysql.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(
|
||||
source_ref=query, content=content
|
||||
),
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns,
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content),
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except pymysql.Error as e:
|
||||
raise ValueError(f"MySQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}")
|
||||
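Usage follows the pattern the docstring describes: the SQL query is the source, and the connection URI travels in the metadata kwarg. A minimal sketch, assuming SourceContent wraps a raw string; the URI and table are made up:

from crewai_tools.rag.loaders.mysql_loader import MySQLLoader
from crewai_tools.rag.source_content import SourceContent

loader = MySQLLoader()
result = loader.load(
    SourceContent("SELECT id, title FROM articles LIMIT 5"),  # the query is the source
    metadata={"db_uri": "mysql://user:secret@localhost:3306/mydb"},  # hypothetical URI
)
print(result.metadata["row_count"])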
packages/tools/src/crewai_tools/rag/loaders/pdf_loader.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""PDF loader for extracting text from PDF files."""

import os
from pathlib import Path
from typing import Any

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PDFLoader(BaseLoader):
    """Loader for PDF files."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract text from a PDF file.

        Args:
            source: The source content containing the PDF file path

        Returns:
            LoaderResult with extracted text content

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            ImportError: If required PDF libraries aren't installed
        """
        try:
            import pypdf
        except ImportError:
            try:
                import PyPDF2 as pypdf
            except ImportError:
                raise ImportError(
                    "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
                )

        file_path = source.source

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"PDF file not found: {file_path}")

        text_content = []
        metadata: dict[str, Any] = {
            "source": str(file_path),
            "file_name": Path(file_path).name,
            "file_type": "pdf",
        }

        try:
            with open(file_path, "rb") as file:
                pdf_reader = pypdf.PdfReader(file)
                metadata["num_pages"] = len(pdf_reader.pages)

                for page_num, page in enumerate(pdf_reader.pages, 1):
                    page_text = page.extract_text()
                    if page_text.strip():
                        text_content.append(f"Page {page_num}:\n{page_text}")
        except Exception as e:
            raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e

        if not text_content:
            content = f"[PDF file with no extractable text: {Path(file_path).name}]"
        else:
            content = "\n\n".join(text_content)

        return LoaderResult(
            content=content,
            source=str(file_path),
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
        )
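A usage sketch for the PDF loader, with a hypothetical file path (not part of the diff):

from crewai_tools.rag.loaders.pdf_loader import PDFLoader
from crewai_tools.rag.source_content import SourceContent

loader = PDFLoader()
result = loader.load(SourceContent("docs/report.pdf"))  # hypothetical path
print(result.metadata["num_pages"])  # page count recorded during extraction
print(result.content[:200])  # "Page 1:\n..." style text, pages joined by blank lines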
packages/tools/src/crewai_tools/rag/loaders/postgres_loader.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""PostgreSQL database loader."""

from urllib.parse import urlparse

import psycopg2
from psycopg2.extras import RealDictCursor

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PostgresLoader(BaseLoader):
    """Loader for PostgreSQL database content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a PostgreSQL database table.

        Args:
            source: SQL query (e.g., "SELECT * FROM table_name")
            **kwargs: Additional arguments including db_uri

        Returns:
            LoaderResult with database content
        """
        metadata = kwargs.get("metadata", {})
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for PostgreSQL loader")

        query = source.source

        parsed = urlparse(db_uri)
        if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
            raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")

        connection_params = {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 5432,
            "user": parsed.username,
            "password": parsed.password,
            "database": parsed.path.lstrip("/") if parsed.path else None,
            "cursor_factory": RealDictCursor,
        }

        if not connection_params["database"]:
            raise ValueError("Database name is required in the URI")

        try:
            connection = psycopg2.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()

                if not rows:
                    content = "No data found in the table"
                    return LoaderResult(
                        content=content,
                        metadata={"source": query, "row_count": 0},
                        doc_id=self.generate_doc_id(
                            source_ref=query, content=content
                        ),
                    )

                text_parts = []

                columns = list(rows[0].keys())
                text_parts.append(f"Columns: {', '.join(columns)}")
                text_parts.append(f"Total rows: {len(rows)}")
                text_parts.append("")

                for i, row in enumerate(rows, 1):
                    text_parts.append(f"Row {i}:")
                    for col, val in row.items():
                        if val is not None:
                            text_parts.append(f" {col}: {val}")
                    text_parts.append("")

                content = "\n".join(text_parts)

                if len(content) > 100000:
                    content = content[:100000] + "\n\n[Content truncated...]"

                return LoaderResult(
                    content=content,
                    metadata={
                        "source": query,
                        "database": connection_params["database"],
                        "row_count": len(rows),
                        "columns": columns,
                    },
                    doc_id=self.generate_doc_id(source_ref=query, content=content),
                )
            finally:
                connection.close()
        except psycopg2.Error as e:
            raise ValueError(f"PostgreSQL database error: {e}") from e
        except Exception as e:
            raise ValueError(f"Failed to load data from PostgreSQL: {e}") from e
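The PostgreSQL loader mirrors the MySQL one, down to the row-to-text serialization. A sketch of a call and the content shape it produces (illustrative values; the URI is made up):

from crewai_tools.rag.loaders.postgres_loader import PostgresLoader
from crewai_tools.rag.source_content import SourceContent

loader = PostgresLoader()
result = loader.load(
    SourceContent("SELECT id, name FROM users LIMIT 2"),
    metadata={"db_uri": "postgresql://user:secret@localhost:5432/appdb"},  # hypothetical
)
# result.content is plain text shaped like:
#   Columns: id, name
#   Total rows: 2
#
#   Row 1:
#    id: 1
#    name: Ada
#   ...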
@@ -1,18 +1,23 @@
import re

import requests
from bs4 import BeautifulSoup

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class WebPageLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        url = source_content.source
        headers = kwargs.get("headers", {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
            "Accept-Language": "en-US,en;q=0.9",
        })
        headers = kwargs.get(
            "headers",
            {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
                "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
                "Accept-Language": "en-US,en;q=0.9",
            },
        )

        try:
            response = requests.get(url, timeout=15, headers=headers)
@@ -28,20 +33,22 @@ class WebPageLoader(BaseLoader):
            text = re.sub("\\s+\n\\s+", "\n", text)
            text = text.strip()

            title = soup.title.string.strip() if soup.title and soup.title.string else ""
            title = (
                soup.title.string.strip() if soup.title and soup.title.string else ""
            )
            metadata = {
                "url": url,
                "title": title,
                "status_code": response.status_code,
                "content_type": response.headers.get("content-type", "")
                "content_type": response.headers.get("content-type", ""),
            }

            return LoaderResult(
                content=text,
                source=url,
                metadata=metadata,
                doc_id=self.generate_doc_id(source_ref=url, content=text)
                doc_id=self.generate_doc_id(source_ref=url, content=text),
            )

        except Exception as e:
            raise ValueError(f"Error loading webpage {url}: {str(e)}")
            raise ValueError(f"Error loading webpage {url}: {e!s}") from e
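A usage sketch; custom headers are optional and fall back to the browser-like defaults above (the URL is illustrative):

from crewai_tools.rag.source_content import SourceContent

loader = WebPageLoader()  # as defined above
result = loader.load(SourceContent("https://example.com"))
print(result.metadata["title"], result.metadata["status_code"])
result = loader.load(SourceContent("https://example.com"), headers={"User-Agent": "my-bot/1.0"})  # override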
@@ -1,9 +1,9 @@
import os
import xml.etree.ElementTree as ET

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class XMLLoader(BaseLoader):
    def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
        source_ref = source_content.source_ref
@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):

        if source_content.is_url():
            content = self._load_from_url(source_ref, kwargs)
        elif os.path.exists(source_ref):
        elif source_content.path_exists():
            content = self._load_from_file(source_ref)

        return self._parse_xml(content, source_ref)
@@ -19,17 +19,20 @@ class XMLLoader(BaseLoader):
    def _load_from_url(self, url: str, kwargs: dict) -> str:
        import requests

        headers = kwargs.get("headers", {
            "Accept": "application/xml, text/xml, text/plain",
            "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)"
        })
        headers = kwargs.get(
            "headers",
            {
                "Accept": "application/xml, text/xml, text/plain",
                "User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)",
            },
        )

        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            return response.text
        except Exception as e:
            raise ValueError(f"Error fetching XML from URL {url}: {str(e)}")
            raise ValueError(f"Error fetching XML from URL {url}: {e!s}") from e

    def _load_from_file(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as file:
@@ -37,7 +40,7 @@ class XMLLoader(BaseLoader):

    def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
        try:
            if content.strip().startswith('<'):
            if content.strip().startswith("<"):
                root = ET.fromstring(content)
            else:
                root = ET.parse(source_ref).getroot()
@@ -57,5 +60,5 @@ class XMLLoader(BaseLoader):
            content=text,
            source=source_ref,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text)
            doc_id=self.generate_doc_id(source_ref=source_ref, content=text),
        )
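A usage sketch; the loader dispatches on whether the source is a URL or an existing local path, so both forms below work (the URL and path are illustrative):

from crewai_tools.rag.source_content import SourceContent

loader = XMLLoader()  # as defined above
result = loader.load(SourceContent("https://example.com/feed.xml"))  # fetched with a 30s timeout
result = loader.load(SourceContent("data/catalog.xml"))  # hypothetical local file
print(result.content[:200])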
@@ -0,0 +1,162 @@
"""YouTube channel loader for extracting content from YouTube channels."""

import re
from typing import Any

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class YoutubeChannelLoader(BaseLoader):
    """Loader for YouTube channels."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract content from a YouTube channel.

        Args:
            source: The source content containing the YouTube channel URL

        Returns:
            LoaderResult with channel content

        Raises:
            ImportError: If required YouTube libraries aren't installed
            ValueError: If the URL is not a valid YouTube channel URL
        """
        try:
            from pytube import Channel
        except ImportError:
            raise ImportError(
                "YouTube channel support requires pytube. Install with: uv add pytube"
            )

        channel_url = source.source

        if not any(
            pattern in channel_url
            for pattern in [
                "youtube.com/channel/",
                "youtube.com/c/",
                "youtube.com/@",
                "youtube.com/user/",
            ]
        ):
            raise ValueError(f"Invalid YouTube channel URL: {channel_url}")

        metadata: dict[str, Any] = {
            "source": channel_url,
            "data_type": "youtube_channel",
        }

        try:
            channel = Channel(channel_url)

            metadata["channel_name"] = channel.channel_name
            metadata["channel_id"] = channel.channel_id

            max_videos = kwargs.get("max_videos", 10)
            video_urls = list(channel.video_urls)[:max_videos]
            metadata["num_videos_loaded"] = len(video_urls)
            metadata["total_videos"] = len(list(channel.video_urls))

            content_parts = [
                f"YouTube Channel: {channel.channel_name}",
                f"Channel ID: {channel.channel_id}",
                f"Total Videos: {metadata['total_videos']}",
                f"Videos Loaded: {metadata['num_videos_loaded']}",
                "\n--- Video Summaries ---\n",
            ]

            try:
                from pytube import YouTube
                from youtube_transcript_api import YouTubeTranscriptApi

                for i, video_url in enumerate(video_urls, 1):
                    try:
                        video_id = self._extract_video_id(video_url)
                        if not video_id:
                            continue
                        yt = YouTube(video_url)
                        title = yt.title or f"Video {i}"
                        description = (
                            yt.description[:200] if yt.description else "No description"
                        )

                        content_parts.append(f"\n{i}. {title}")
                        content_parts.append(f" URL: {video_url}")
                        content_parts.append(f" Description: {description}...")

                        try:
                            api = YouTubeTranscriptApi()
                            transcript_list = api.list(video_id)
                            transcript = None

                            try:
                                transcript = transcript_list.find_transcript(["en"])
                            except Exception:
                                try:
                                    transcript = (
                                        transcript_list.find_generated_transcript(
                                            ["en"]
                                        )
                                    )
                                except Exception:
                                    transcript = next(iter(transcript_list), None)

                            if transcript:
                                transcript_data = transcript.fetch()
                                text_parts = []
                                char_count = 0
                                for entry in transcript_data:
                                    text = (
                                        entry.text.strip()
                                        if hasattr(entry, "text")
                                        else ""
                                    )
                                    if text:
                                        text_parts.append(text)
                                        char_count += len(text)
                                        if char_count > 500:
                                            break

                                if text_parts:
                                    preview = " ".join(text_parts)[:500]
                                    content_parts.append(
                                        f" Transcript Preview: {preview}..."
                                    )
                        except Exception:
                            content_parts.append(" Transcript: Not available")

                    except Exception as e:
                        content_parts.append(f"\n{i}. Error loading video: {e!s}")

            except ImportError:
                for i, video_url in enumerate(video_urls, 1):
                    content_parts.append(f"\n{i}. {video_url}")

            content = "\n".join(content_parts)

        except Exception as e:
            raise ValueError(
                f"Unable to load YouTube channel {channel_url}: {e!s}"
            ) from e

        return LoaderResult(
            content=content,
            source=channel_url,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=channel_url, content=content),
        )

    def _extract_video_id(self, url: str) -> str | None:
        """Extract video ID from YouTube URL."""
        patterns = [
            r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)

        return None
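A usage sketch; the handle is made up, pytube must be installed, and the module path is assumed from the loaders package layout:

from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader  # assumed module path
from crewai_tools.rag.source_content import SourceContent

loader = YoutubeChannelLoader()
result = loader.load(
    SourceContent("https://www.youtube.com/@somechannel"),  # hypothetical handle
    max_videos=3,  # caps how many videos are summarized (default 10)
)
print(result.metadata["channel_name"], result.metadata["num_videos_loaded"])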
@@ -0,0 +1,134 @@
"""YouTube video loader for extracting transcripts from YouTube videos."""

import re
from typing import Any
from urllib.parse import parse_qs, urlparse

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class YoutubeVideoLoader(BaseLoader):
    """Loader for YouTube videos."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract transcript from a YouTube video.

        Args:
            source: The source content containing the YouTube URL

        Returns:
            LoaderResult with transcript content

        Raises:
            ImportError: If required YouTube libraries aren't installed
            ValueError: If the URL is not a valid YouTube video URL
        """
        try:
            from youtube_transcript_api import YouTubeTranscriptApi
        except ImportError:
            raise ImportError(
                "YouTube support requires youtube-transcript-api. "
                "Install with: uv add youtube-transcript-api"
            )

        video_url = source.source
        video_id = self._extract_video_id(video_url)

        if not video_id:
            raise ValueError(f"Invalid YouTube URL: {video_url}")

        metadata: dict[str, Any] = {
            "source": video_url,
            "video_id": video_id,
            "data_type": "youtube_video",
        }

        try:
            api = YouTubeTranscriptApi()
            transcript_list = api.list(video_id)

            transcript = None
            try:
                transcript = transcript_list.find_transcript(["en"])
            except Exception:
                try:
                    transcript = transcript_list.find_generated_transcript(["en"])
                except Exception:
                    transcript = next(iter(transcript_list))

            if transcript:
                metadata["language"] = transcript.language
                metadata["is_generated"] = transcript.is_generated

                transcript_data = transcript.fetch()

                text_content = []
                for entry in transcript_data:
                    text = entry.text.strip() if hasattr(entry, "text") else ""
                    if text:
                        text_content.append(text)

                content = " ".join(text_content)

                try:
                    from pytube import YouTube

                    yt = YouTube(video_url)
                    metadata["title"] = yt.title
                    metadata["author"] = yt.author
                    metadata["length_seconds"] = yt.length
                    metadata["description"] = (
                        yt.description[:500] if yt.description else None
                    )

                    if yt.title:
                        content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
                except Exception:
                    pass
            else:
                raise ValueError(
                    f"No transcript available for YouTube video: {video_id}"
                )

        except Exception as e:
            raise ValueError(
                f"Unable to extract transcript from YouTube video {video_id}: {e!s}"
            ) from e

        return LoaderResult(
            content=content,
            source=video_url,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=video_url, content=content),
        )

    def _extract_video_id(self, url: str) -> str | None:
        """Extract video ID from various YouTube URL formats."""
        patterns = [
            r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)

        try:
            parsed = urlparse(url)
            hostname = parsed.hostname
            if hostname:
                hostname_lower = hostname.lower()
                # Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
                if (
                    hostname_lower == "youtube.com"
                    or hostname_lower.endswith(".youtube.com")
                    or hostname_lower == "youtu.be"
                ):
                    query_params = parse_qs(parsed.query)
                    if "v" in query_params:
                        return query_params["v"][0]
        except Exception:
            pass

        return None
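And for single videos, a sketch (video URL illustrative; youtube-transcript-api is required, pytube is optional and only enriches title/author metadata):

from crewai_tools.rag.source_content import SourceContent

loader = YoutubeVideoLoader()  # as defined above
result = loader.load(SourceContent("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))
print(result.metadata.get("language"), result.metadata.get("is_generated"))
print(result.content[:150])  # "Title: ...\n\nAuthor: ...\n\nTranscript:\n..." when pytube is available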
@@ -1,4 +1,31 @@
import hashlib
from typing import Any


def compute_sha256(content: str) -> str:
    return hashlib.sha256(content.encode("utf-8")).hexdigest()


def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
    """Sanitize metadata to ensure ChromaDB compatibility.

    ChromaDB only accepts str, int, float, or bool values in metadata.
    This function converts other types to strings.

    Args:
        metadata: Dictionary of metadata to sanitize

    Returns:
        Sanitized metadata dictionary with only ChromaDB-compatible types
    """
    sanitized = {}
    for key, value in metadata.items():
        if isinstance(value, (str, int, float, bool)) or value is None:
            sanitized[key] = value
        elif isinstance(value, (list, tuple)):
            # Convert lists/tuples to pipe-separated strings
            sanitized[key] = " | ".join(str(v) for v in value)
        else:
            # Convert other types to string
            sanitized[key] = str(value)
    return sanitized
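A quick illustration of the sanitizer's behavior (values are illustrative):

meta = {
    "columns": ["id", "name"],  # list -> pipe-separated string
    "row_count": 2,             # int passes through unchanged
    "nested": {"a": 1},         # anything else -> str()
}
print(sanitize_metadata_for_chromadb(meta))
# {'columns': 'id | name', 'row_count': 2, 'nested': "{'a': 1}"}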
@@ -1,8 +1,8 @@
import os
from urllib.parse import urlparse
from typing import TYPE_CHECKING
from pathlib import Path
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from crewai_tools.rag.misc import compute_sha256

@@ -34,7 +34,7 @@ class SourceContent:

    @cached_property
    def source_ref(self) -> str:
        """"
        """ "
        Returns the source reference for the content.
        If the content is a URL or a local file, returns the source.
        Otherwise, returns the hash of the content.
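The docstring's contract, sketched (illustrative; assumes SourceContent can be constructed from a raw string):

SourceContent("https://example.com/page.html").source_ref  # -> the URL itself
SourceContent("./notes.txt").source_ref                    # -> the file path, if it exists
SourceContent("free-form text").source_ref                 # -> compute_sha256("free-form text")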