Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 16:18:30 +00:00)
feat: replace embedchain with native crewai adapter (#451)

- Remove embedchain adapter; add crewai rag adapter and update all search tools
- Add loaders: pdf, youtube (video & channel), github, docs site, mysql, postgresql
- Add configurable similarity threshold, limit params, and embedding_model support
- Improve chromadb compatibility (sanitize metadata, convert columns, fix chunking)
- Fix xml encoding, Python 3.10 issues, and youtube url spoofing
- Update crewai dependency and instructions; refresh uv.lock
- Update tests for new rag adapter and search params
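As a minimal sketch of the new similarity_threshold and limit parameters on one of the updated search tools (the docs URL is illustrative, and the constructor is assumed to accept docs_url as in prior releases):

    from crewai_tools import CodeDocsSearchTool

    tool = CodeDocsSearchTool(docs_url="https://docs.example.com")  # hypothetical docs site
    # Both parameters are optional; they override the adapter defaults (0.6 and 5).
    print(tool.run(search_query="authentication flow", similarity_threshold=0.7, limit=3))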
src/crewai_tools/adapters/crewai_rag_adapter.py (new file, 215 lines)
@@ -0,0 +1,215 @@
"""Adapter for CrewAI's native RAG system."""

from typing import Any, TypedDict, TypeAlias
from typing_extensions import Unpack
from pathlib import Path
import hashlib

from pydantic import Field, PrivateAttr
from crewai.rag.config.utils import get_rag_client
from crewai.rag.config.types import RagConfigType
from crewai.rag.types import BaseRecord, SearchResult
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client

from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.rag.chunkers.base_chunker import BaseChunker

ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""
    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(self, question: str, similarity_threshold: float | None = None, limit: int | None = None) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask
            similarity_threshold: Minimum similarity score for results (default: 0.6)
            limit: Maximum number of results to return (default: 5)

        Returns:
            Relevant content from the knowledge base
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = similarity_threshold if similarity_threshold is not None else self.similarity_threshold

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold
        )

        if not results:
            return "No relevant content found."

        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)

        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        from crewai_tools.rag.data_types import DataTypes, DataType
        from crewai_tools.rag.source_content import SourceContent
        from crewai_tools.rag.base_loader import LoaderResult
        import os

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            if not data_type:
                data_type = DataTypes.from_content(source_ref)

            if data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")

                # Define binary and non-text file extensions to skip
                binary_extensions = {'.pyc', '.pyo', '.png', '.jpg', '.jpeg', '.gif',
                                     '.bmp', '.ico', '.svg', '.webp', '.pdf', '.zip',
                                     '.tar', '.gz', '.bz2', '.7z', '.rar', '.exe',
                                     '.dll', '.so', '.dylib', '.bin', '.dat', '.db',
                                     '.sqlite', '.class', '.jar', '.war', '.ear'}

                for root, dirs, files in os.walk(source_ref):
                    dirs[:] = [d for d in dirs if not d.startswith('.')]

                    for filename in files:
                        if filename.startswith('.'):
                            continue

                        # Skip binary files based on extension
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue

                        # Skip __pycache__ directories
                        if '__pycache__' in root:
                            continue

                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()

                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)

                            file_chunks = file_chunker.chunk(file_result.content)

                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)

                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))

                                chunk_id = hashlib.sha256(f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()).hexdigest()

                                documents.append({
                                    "doc_id": chunk_id,
                                    "content": file_chunk,
                                    "metadata": sanitize_metadata_for_chromadb(file_metadata)
                                })
                        except Exception:
                            # Silently skip files that can't be processed
                            continue
            else:
                metadata: dict[str, Any] = base_metadata.copy()

                if data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
                                 DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")

                loader = data_type.get_loader()
                chunker = data_type.get_chunker()

                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)

                chunks = chunker.chunk(loader_result.content)

                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref

                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))

                    chunk_id = hashlib.sha256(f"{loader_result.doc_id}_{i}_{chunk}".encode()).hexdigest()

                    documents.append({
                        "doc_id": chunk_id,
                        "content": chunk,
                        "metadata": sanitize_metadata_for_chromadb(chunk_metadata)
                    })

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name,
                documents=documents
            )
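As a minimal usage sketch of the adapter above, assuming the default client configuration and a hypothetical markdown file:

    from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter

    adapter = CrewAIRagAdapter(collection_name="project_docs")  # falls back to get_rag_client() when no config is given
    adapter.add("notes/architecture.md")                        # hypothetical path; data_type is inferred
    print(adapter.query("How is caching handled?", similarity_threshold=0.5, limit=3))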
@@ -1,34 +0,0 @@
from typing import Any

from crewai_tools.tools.rag.rag_tool import Adapter

try:
    from embedchain import App
    EMBEDCHAIN_AVAILABLE = True
except ImportError:
    EMBEDCHAIN_AVAILABLE = False


class EmbedchainAdapter(Adapter):
    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False

    def __init__(self, **data):
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        result, sources = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize)
        )
        if self.summarize:
            return result
        return "\n\n".join([source[0] for source in sources])

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.embedchain_app.add(*args, **kwargs)
@@ -1,41 +0,0 @@
from typing import Any, Optional

from crewai_tools.tools.rag.rag_tool import Adapter

try:
    from embedchain import App
    EMBEDCHAIN_AVAILABLE = True
except ImportError:
    EMBEDCHAIN_AVAILABLE = False


class PDFEmbedchainAdapter(Adapter):
    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False
    src: Optional[str] = None

    def __init__(self, **data):
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        where = (
            {"app_id": self.embedchain_app.config.id, "source": self.src}
            if self.src
            else None
        )
        result, sources = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize), where=where
        )
        if self.summarize:
            return result
        return "\n\n".join([source[0] for source in sources])

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.src = args[0] if args else None
        self.embedchain_app.add(*args, **kwargs)
@@ -112,7 +112,10 @@ class RecursiveCharacterTextSplitter:
             if separator == "":
                 doc = "".join(current_doc)
             else:
-                doc = separator.join(current_doc)
+                if self._keep_separator and separator == " ":
+                    doc = "".join(current_doc)
+                else:
+                    doc = separator.join(current_doc)
 
             if doc:
                 docs.append(doc)
@@ -133,7 +136,10 @@ class RecursiveCharacterTextSplitter:
             if separator == "":
                 doc = "".join(current_doc)
             else:
-                doc = separator.join(current_doc)
+                if self._keep_separator and separator == " ":
+                    doc = "".join(current_doc)
+                else:
+                    doc = separator.join(current_doc)
 
             if doc:
                 docs.append(doc)
@@ -25,6 +25,8 @@ class DataType(str, Enum):
     # Web types
     WEBSITE = "website"
     DOCS_SITE = "docs_site"
+    YOUTUBE_VIDEO = "youtube_video"
+    YOUTUBE_CHANNEL = "youtube_channel"
 
     # Raw types
     TEXT = "text"
@@ -34,6 +36,7 @@ class DataType(str, Enum):
         from importlib import import_module
 
         chunkers = {
+            DataType.PDF_FILE: ("text_chunker", "TextChunker"),
             DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
             DataType.TEXT: ("text_chunker", "TextChunker"),
             DataType.DOCX: ("text_chunker", "DocxChunker"),
@@ -45,9 +48,18 @@ class DataType(str, Enum):
             DataType.XML: ("structured_chunker", "XmlChunker"),
 
             DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
+            DataType.DIRECTORY: ("text_chunker", "TextChunker"),
+            DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
+            DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
+            DataType.GITHUB: ("text_chunker", "TextChunker"),
+            DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
+            DataType.MYSQL: ("text_chunker", "TextChunker"),
+            DataType.POSTGRES: ("text_chunker", "TextChunker"),
         }
 
-        module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
+        if self not in chunkers:
+            raise ValueError(f"No chunker defined for {self}")
+        module_name, class_name = chunkers[self]
         module_path = f"crewai_tools.rag.chunkers.{module_name}"
 
         try:
@@ -60,6 +72,7 @@ class DataType(str, Enum):
         from importlib import import_module
 
         loaders = {
+            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
             DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
             DataType.TEXT: ("text_loader", "TextLoader"),
             DataType.XML: ("xml_loader", "XMLLoader"),
@@ -69,9 +82,17 @@ class DataType(str, Enum):
             DataType.DOCX: ("docx_loader", "DOCXLoader"),
             DataType.CSV: ("csv_loader", "CSVLoader"),
             DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
+            DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
+            DataType.YOUTUBE_CHANNEL: ("youtube_channel_loader", "YoutubeChannelLoader"),
+            DataType.GITHUB: ("github_loader", "GithubLoader"),
+            DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
+            DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
+            DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
         }
 
-        module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
+        if self not in loaders:
+            raise ValueError(f"No loader defined for {self}")
+        module_name, class_name = loaders[self]
         module_path = f"crewai_tools.rag.loaders.{module_name}"
         try:
             module = import_module(module_path)
@@ -6,6 +6,9 @@ from crewai_tools.rag.loaders.json_loader import JSONLoader
 from crewai_tools.rag.loaders.docx_loader import DOCXLoader
 from crewai_tools.rag.loaders.csv_loader import CSVLoader
 from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
+from crewai_tools.rag.loaders.pdf_loader import PDFLoader
+from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader
+from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
 
 __all__ = [
     "TextFileLoader",
@@ -17,4 +20,7 @@ __all__ = [
     "DOCXLoader",
     "CSVLoader",
     "DirectoryLoader",
+    "PDFLoader",
+    "YoutubeVideoLoader",
+    "YoutubeChannelLoader",
 ]
src/crewai_tools/rag/loaders/docs_site_loader.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""Documentation site loader."""

from typing import Any
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class DocsSiteLoader(BaseLoader):
    """Loader for documentation websites."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a documentation site.

        Args:
            source: Documentation site URL
            **kwargs: Additional arguments

        Returns:
            LoaderResult with documentation content
        """
        docs_url = source.source

        try:
            response = requests.get(docs_url, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")

        soup = BeautifulSoup(response.text, "html.parser")

        for script in soup(["script", "style"]):
            script.decompose()

        title = soup.find("title")
        title_text = title.get_text(strip=True) if title else "Documentation"

        main_content = None
        for selector in ["main", "article", '[role="main"]', ".content", "#content", ".documentation"]:
            main_content = soup.select_one(selector)
            if main_content:
                break

        if not main_content:
            main_content = soup.find("body")

        if not main_content:
            raise ValueError(f"Unable to extract content from documentation site: {docs_url}")

        text_parts = [f"Title: {title_text}", ""]

        headings = main_content.find_all(["h1", "h2", "h3"])
        if headings:
            text_parts.append("Table of Contents:")
            for heading in headings[:15]:
                level = int(heading.name[1])
                indent = " " * (level - 1)
                text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
            text_parts.append("")

        text = main_content.get_text(separator="\n", strip=True)
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        text_parts.extend(lines)

        nav_links = []
        for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
            nav = soup.select_one(nav_selector)
            if nav:
                links = nav.find_all("a", href=True)
                for link in links[:20]:
                    href = link["href"]
                    if not href.startswith(("http://", "https://", "mailto:", "#")):
                        full_url = urljoin(docs_url, href)
                        nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")

        if nav_links:
            text_parts.append("")
            text_parts.append("Related documentation pages:")
            text_parts.extend(nav_links[:10])

        content = "\n".join(text_parts)

        if len(content) > 100000:
            content = content[:100000] + "\n\n[Content truncated...]"

        return LoaderResult(
            content=content,
            metadata={
                "source": docs_url,
                "title": title_text,
                "domain": urlparse(docs_url).netloc
            },
            doc_id=self.generate_doc_id(source_ref=docs_url, content=content)
        )
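As a sketch of invoking this loader directly (the URL is illustrative):

    from crewai_tools.rag.loaders.docs_site_loader import DocsSiteLoader
    from crewai_tools.rag.source_content import SourceContent

    result = DocsSiteLoader().load(SourceContent("https://docs.example.com/guide"))  # hypothetical docs site
    print(result.metadata["title"], result.metadata["domain"])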
src/crewai_tools/rag/loaders/github_loader.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""GitHub repository content loader."""

from typing import Any

from github import Github, GithubException

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class GithubLoader(BaseLoader):
    """Loader for GitHub repository content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a GitHub repository.

        Args:
            source: GitHub repository URL
            **kwargs: Additional arguments including gh_token and content_types

        Returns:
            LoaderResult with repository content
        """
        metadata = kwargs.get("metadata", {})
        gh_token = metadata.get("gh_token")
        content_types = metadata.get("content_types", ["code", "repo"])

        repo_url = source.source
        if not repo_url.startswith("https://github.com/"):
            raise ValueError(f"Invalid GitHub URL: {repo_url}")

        parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
        if len(parts) < 2:
            raise ValueError(f"Invalid GitHub repository URL: {repo_url}")

        repo_name = f"{parts[0]}/{parts[1]}"

        g = Github(gh_token) if gh_token else Github()

        try:
            repo = g.get_repo(repo_name)
        except GithubException as e:
            raise ValueError(f"Unable to access repository {repo_name}: {e}")

        all_content = []

        if "repo" in content_types:
            all_content.append(f"Repository: {repo.full_name}")
            all_content.append(f"Description: {repo.description or 'No description'}")
            all_content.append(f"Language: {repo.language or 'Not specified'}")
            all_content.append(f"Stars: {repo.stargazers_count}")
            all_content.append(f"Forks: {repo.forks_count}")
            all_content.append("")

        if "code" in content_types:
            try:
                readme = repo.get_readme()
                all_content.append("README:")
                all_content.append(readme.decoded_content.decode("utf-8", errors="ignore"))
                all_content.append("")
            except GithubException:
                pass

            try:
                contents = repo.get_contents("")
                if isinstance(contents, list):
                    all_content.append("Repository structure:")
                    for content_file in contents[:20]:
                        all_content.append(f"- {content_file.path} ({content_file.type})")
                    all_content.append("")
            except GithubException:
                pass

        if "pr" in content_types:
            prs = repo.get_pulls(state="open")
            pr_list = list(prs[:5])
            if pr_list:
                all_content.append("Recent Pull Requests:")
                for pr in pr_list:
                    all_content.append(f"- PR #{pr.number}: {pr.title}")
                    if pr.body:
                        body_preview = pr.body[:200].replace("\n", " ")
                        all_content.append(f"  {body_preview}")
                all_content.append("")

        if "issue" in content_types:
            issues = repo.get_issues(state="open")
            issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
            if issue_list:
                all_content.append("Recent Issues:")
                for issue in issue_list:
                    all_content.append(f"- Issue #{issue.number}: {issue.title}")
                    if issue.body:
                        body_preview = issue.body[:200].replace("\n", " ")
                        all_content.append(f"  {body_preview}")
                all_content.append("")

        if not all_content:
            raise ValueError(f"No content could be loaded from repository: {repo_url}")

        content = "\n".join(all_content)
        return LoaderResult(
            content=content,
            metadata={
                "source": repo_url,
                "repo": repo_name,
                "content_types": content_types
            },
            doc_id=self.generate_doc_id(source_ref=repo_url, content=content)
        )
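As a sketch of direct use; per the adapter above, gh_token and content_types arrive via the metadata kwarg (None falls back to unauthenticated access):

    from crewai_tools.rag.loaders.github_loader import GithubLoader
    from crewai_tools.rag.source_content import SourceContent

    result = GithubLoader().load(
        SourceContent("https://github.com/crewAIInc/crewAI"),
        metadata={"gh_token": None, "content_types": ["repo", "code"]},
    )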
src/crewai_tools/rag/loaders/mysql_loader.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""MySQL database loader."""

from typing import Any
from urllib.parse import urlparse

import pymysql

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class MySQLLoader(BaseLoader):
    """Loader for MySQL database content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a MySQL database table.

        Args:
            source: SQL query (e.g., "SELECT * FROM table_name")
            **kwargs: Additional arguments including db_uri

        Returns:
            LoaderResult with database content
        """
        metadata = kwargs.get("metadata", {})
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for MySQL loader")

        query = source.source

        parsed = urlparse(db_uri)
        if parsed.scheme not in ["mysql", "mysql+pymysql"]:
            raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")

        connection_params = {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 3306,
            "user": parsed.username,
            "password": parsed.password,
            "database": parsed.path.lstrip("/") if parsed.path else None,
            "charset": "utf8mb4",
            "cursorclass": pymysql.cursors.DictCursor
        }

        if not connection_params["database"]:
            raise ValueError("Database name is required in the URI")

        try:
            connection = pymysql.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()

                    if not rows:
                        content = "No data found in the table"
                        return LoaderResult(
                            content=content,
                            metadata={"source": query, "row_count": 0},
                            doc_id=self.generate_doc_id(source_ref=query, content=content)
                        )

                    text_parts = []

                    columns = list(rows[0].keys())
                    text_parts.append(f"Columns: {', '.join(columns)}")
                    text_parts.append(f"Total rows: {len(rows)}")
                    text_parts.append("")

                    for i, row in enumerate(rows, 1):
                        text_parts.append(f"Row {i}:")
                        for col, val in row.items():
                            if val is not None:
                                text_parts.append(f"  {col}: {val}")
                        text_parts.append("")

                    content = "\n".join(text_parts)

                    if len(content) > 100000:
                        content = content[:100000] + "\n\n[Content truncated...]"

                    return LoaderResult(
                        content=content,
                        metadata={
                            "source": query,
                            "database": connection_params["database"],
                            "row_count": len(rows),
                            "columns": columns
                        },
                        doc_id=self.generate_doc_id(source_ref=query, content=content)
                    )
            finally:
                connection.close()
        except pymysql.Error as e:
            raise ValueError(f"MySQL database error: {e}")
        except Exception as e:
            raise ValueError(f"Failed to load data from MySQL: {e}")
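As a sketch of direct use; the source is the SQL query itself and db_uri travels in metadata (credentials below are hypothetical):

    from crewai_tools.rag.loaders.mysql_loader import MySQLLoader
    from crewai_tools.rag.source_content import SourceContent

    result = MySQLLoader().load(
        SourceContent("SELECT id, title FROM articles LIMIT 50"),
        metadata={"db_uri": "mysql://user:pass@localhost:3306/blog"},
    )
    print(result.metadata["row_count"])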
src/crewai_tools/rag/loaders/pdf_loader.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""PDF loader for extracting text from PDF files."""

import os
from pathlib import Path
from typing import Any

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PDFLoader(BaseLoader):
    """Loader for PDF files."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract text from a PDF file.

        Args:
            source: The source content containing the PDF file path

        Returns:
            LoaderResult with extracted text content

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            ImportError: If required PDF libraries aren't installed
        """
        try:
            import pypdf
        except ImportError:
            try:
                import PyPDF2 as pypdf
            except ImportError:
                raise ImportError(
                    "PDF support requires pypdf or PyPDF2. "
                    "Install with: uv add pypdf"
                )

        file_path = source.source

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"PDF file not found: {file_path}")

        text_content = []
        metadata: dict[str, Any] = {
            "source": str(file_path),
            "file_name": Path(file_path).name,
            "file_type": "pdf"
        }

        try:
            with open(file_path, 'rb') as file:
                pdf_reader = pypdf.PdfReader(file)
                metadata["num_pages"] = len(pdf_reader.pages)

                for page_num, page in enumerate(pdf_reader.pages, 1):
                    page_text = page.extract_text()
                    if page_text.strip():
                        text_content.append(f"Page {page_num}:\n{page_text}")
        except Exception as e:
            raise ValueError(f"Error reading PDF file {file_path}: {str(e)}")

        if not text_content:
            content = f"[PDF file with no extractable text: {Path(file_path).name}]"
        else:
            content = "\n\n".join(text_content)

        return LoaderResult(
            content=content,
            source=str(file_path),
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=str(file_path), content=content)
        )
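As a sketch of direct use with a hypothetical file path:

    from crewai_tools.rag.loaders.pdf_loader import PDFLoader
    from crewai_tools.rag.source_content import SourceContent

    result = PDFLoader().load(SourceContent("reports/q3.pdf"))  # hypothetical path
    print(result.metadata["num_pages"])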
src/crewai_tools/rag/loaders/postgres_loader.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""PostgreSQL database loader."""

from typing import Any
from urllib.parse import urlparse

import psycopg2
from psycopg2.extras import RealDictCursor

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PostgresLoader(BaseLoader):
    """Loader for PostgreSQL database content."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load content from a PostgreSQL database table.

        Args:
            source: SQL query (e.g., "SELECT * FROM table_name")
            **kwargs: Additional arguments including db_uri

        Returns:
            LoaderResult with database content
        """
        metadata = kwargs.get("metadata", {})
        db_uri = metadata.get("db_uri")

        if not db_uri:
            raise ValueError("Database URI is required for PostgreSQL loader")

        query = source.source

        parsed = urlparse(db_uri)
        if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
            raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")

        connection_params = {
            "host": parsed.hostname or "localhost",
            "port": parsed.port or 5432,
            "user": parsed.username,
            "password": parsed.password,
            "database": parsed.path.lstrip("/") if parsed.path else None,
            "cursor_factory": RealDictCursor
        }

        if not connection_params["database"]:
            raise ValueError("Database name is required in the URI")

        try:
            connection = psycopg2.connect(**connection_params)
            try:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()

                    if not rows:
                        content = "No data found in the table"
                        return LoaderResult(
                            content=content,
                            metadata={"source": query, "row_count": 0},
                            doc_id=self.generate_doc_id(source_ref=query, content=content)
                        )

                    text_parts = []

                    columns = list(rows[0].keys())
                    text_parts.append(f"Columns: {', '.join(columns)}")
                    text_parts.append(f"Total rows: {len(rows)}")
                    text_parts.append("")

                    for i, row in enumerate(rows, 1):
                        text_parts.append(f"Row {i}:")
                        for col, val in row.items():
                            if val is not None:
                                text_parts.append(f"  {col}: {val}")
                        text_parts.append("")

                    content = "\n".join(text_parts)

                    if len(content) > 100000:
                        content = content[:100000] + "\n\n[Content truncated...]"

                    return LoaderResult(
                        content=content,
                        metadata={
                            "source": query,
                            "database": connection_params["database"],
                            "row_count": len(rows),
                            "columns": columns
                        },
                        doc_id=self.generate_doc_id(source_ref=query, content=content)
                    )
            finally:
                connection.close()
        except psycopg2.Error as e:
            raise ValueError(f"PostgreSQL database error: {e}")
        except Exception as e:
            raise ValueError(f"Failed to load data from PostgreSQL: {e}")
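Direct use mirrors the MySQL loader; the connection string below is hypothetical:

    from crewai_tools.rag.loaders.postgres_loader import PostgresLoader
    from crewai_tools.rag.source_content import SourceContent

    result = PostgresLoader().load(
        SourceContent("SELECT * FROM customers LIMIT 20"),
        metadata={"db_uri": "postgresql://user:pass@localhost:5432/crm"},
    )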
@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):
 
         if source_content.is_url():
             content = self._load_from_url(source_ref, kwargs)
-        elif os.path.exists(source_ref):
+        elif source_content.path_exists():
             content = self._load_from_file(source_ref)
 
         return self._parse_xml(content, source_ref)
src/crewai_tools/rag/loaders/youtube_channel_loader.py (new file, 141 lines)
@@ -0,0 +1,141 @@
"""YouTube channel loader for extracting content from YouTube channels."""

import re
from typing import Any

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class YoutubeChannelLoader(BaseLoader):
    """Loader for YouTube channels."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract content from a YouTube channel.

        Args:
            source: The source content containing the YouTube channel URL

        Returns:
            LoaderResult with channel content

        Raises:
            ImportError: If required YouTube libraries aren't installed
            ValueError: If the URL is not a valid YouTube channel URL
        """
        try:
            from pytube import Channel
        except ImportError:
            raise ImportError(
                "YouTube channel support requires pytube. "
                "Install with: uv add pytube"
            )

        channel_url = source.source

        if not any(pattern in channel_url for pattern in ['youtube.com/channel/', 'youtube.com/c/', 'youtube.com/@', 'youtube.com/user/']):
            raise ValueError(f"Invalid YouTube channel URL: {channel_url}")

        metadata: dict[str, Any] = {
            "source": channel_url,
            "data_type": "youtube_channel"
        }

        try:
            channel = Channel(channel_url)

            metadata["channel_name"] = channel.channel_name
            metadata["channel_id"] = channel.channel_id

            max_videos = kwargs.get('max_videos', 10)
            video_urls = list(channel.video_urls)[:max_videos]
            metadata["num_videos_loaded"] = len(video_urls)
            metadata["total_videos"] = len(list(channel.video_urls))

            content_parts = [
                f"YouTube Channel: {channel.channel_name}",
                f"Channel ID: {channel.channel_id}",
                f"Total Videos: {metadata['total_videos']}",
                f"Videos Loaded: {metadata['num_videos_loaded']}",
                "\n--- Video Summaries ---\n"
            ]

            try:
                from youtube_transcript_api import YouTubeTranscriptApi
                from pytube import YouTube

                for i, video_url in enumerate(video_urls, 1):
                    try:
                        video_id = self._extract_video_id(video_url)
                        if not video_id:
                            continue
                        yt = YouTube(video_url)
                        title = yt.title or f"Video {i}"
                        description = yt.description[:200] if yt.description else "No description"

                        content_parts.append(f"\n{i}. {title}")
                        content_parts.append(f"   URL: {video_url}")
                        content_parts.append(f"   Description: {description}...")

                        try:
                            api = YouTubeTranscriptApi()
                            transcript_list = api.list(video_id)
                            transcript = None

                            try:
                                transcript = transcript_list.find_transcript(['en'])
                            except:
                                try:
                                    transcript = transcript_list.find_generated_transcript(['en'])
                                except:
                                    transcript = next(iter(transcript_list), None)

                            if transcript:
                                transcript_data = transcript.fetch()
                                text_parts = []
                                char_count = 0
                                for entry in transcript_data:
                                    text = entry.text.strip() if hasattr(entry, 'text') else ''
                                    if text:
                                        text_parts.append(text)
                                        char_count += len(text)
                                        if char_count > 500:
                                            break

                                if text_parts:
                                    preview = ' '.join(text_parts)[:500]
                                    content_parts.append(f"   Transcript Preview: {preview}...")
                        except:
                            content_parts.append("   Transcript: Not available")

                    except Exception as e:
                        content_parts.append(f"\n{i}. Error loading video: {str(e)}")

            except ImportError:
                for i, video_url in enumerate(video_urls, 1):
                    content_parts.append(f"\n{i}. {video_url}")

            content = '\n'.join(content_parts)

        except Exception as e:
            raise ValueError(f"Unable to load YouTube channel {channel_url}: {str(e)}") from e

        return LoaderResult(
            content=content,
            source=channel_url,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=channel_url, content=content)
        )

    def _extract_video_id(self, url: str) -> str | None:
        """Extract video ID from YouTube URL."""
        patterns = [
            r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)

        return None
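As a sketch of direct use; the channel handle is illustrative, and max_videos caps how many videos are summarized:

    from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
    from crewai_tools.rag.source_content import SourceContent

    result = YoutubeChannelLoader().load(
        SourceContent("https://www.youtube.com/@example"),  # hypothetical channel
        max_videos=3,
    )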
src/crewai_tools/rag/loaders/youtube_video_loader.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""YouTube video loader for extracting transcripts from YouTube videos."""

import re
from typing import Any
from urllib.parse import urlparse, parse_qs

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class YoutubeVideoLoader(BaseLoader):
    """Loader for YouTube videos."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
        """Load and extract transcript from a YouTube video.

        Args:
            source: The source content containing the YouTube URL

        Returns:
            LoaderResult with transcript content

        Raises:
            ImportError: If required YouTube libraries aren't installed
            ValueError: If the URL is not a valid YouTube video URL
        """
        try:
            from youtube_transcript_api import YouTubeTranscriptApi
        except ImportError:
            raise ImportError(
                "YouTube support requires youtube-transcript-api. "
                "Install with: uv add youtube-transcript-api"
            )

        video_url = source.source
        video_id = self._extract_video_id(video_url)

        if not video_id:
            raise ValueError(f"Invalid YouTube URL: {video_url}")

        metadata: dict[str, Any] = {
            "source": video_url,
            "video_id": video_id,
            "data_type": "youtube_video"
        }

        try:
            api = YouTubeTranscriptApi()
            transcript_list = api.list(video_id)

            transcript = None
            try:
                transcript = transcript_list.find_transcript(['en'])
            except:
                try:
                    transcript = transcript_list.find_generated_transcript(['en'])
                except:
                    transcript = next(iter(transcript_list))

            if transcript:
                metadata["language"] = transcript.language
                metadata["is_generated"] = transcript.is_generated

                transcript_data = transcript.fetch()

                text_content = []
                for entry in transcript_data:
                    text = entry.text.strip() if hasattr(entry, 'text') else ''
                    if text:
                        text_content.append(text)

                content = ' '.join(text_content)

                try:
                    from pytube import YouTube
                    yt = YouTube(video_url)
                    metadata["title"] = yt.title
                    metadata["author"] = yt.author
                    metadata["length_seconds"] = yt.length
                    metadata["description"] = yt.description[:500] if yt.description else None

                    if yt.title:
                        content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
                except:
                    pass
            else:
                raise ValueError(f"No transcript available for YouTube video: {video_id}")

        except Exception as e:
            raise ValueError(f"Unable to extract transcript from YouTube video {video_id}: {str(e)}") from e

        return LoaderResult(
            content=content,
            source=video_url,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=video_url, content=content)
        )

    def _extract_video_id(self, url: str) -> str | None:
        """Extract video ID from various YouTube URL formats."""
        patterns = [
            r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)

        try:
            parsed = urlparse(url)
            hostname = parsed.hostname
            if hostname:
                hostname_lower = hostname.lower()
                # Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
                if hostname_lower == 'youtube.com' or hostname_lower.endswith('.youtube.com') or hostname_lower == 'youtu.be':
                    query_params = parse_qs(parsed.query)
                    if 'v' in query_params:
                        return query_params['v'][0]
        except:
            pass

        return None
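As a sketch of direct use with a placeholder video ID:

    from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader
    from crewai_tools.rag.source_content import SourceContent

    result = YoutubeVideoLoader().load(SourceContent("https://www.youtube.com/watch?v=VIDEO_ID"))  # placeholder ID
    print(result.metadata.get("language"), result.content[:200])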
@@ -1,4 +1,29 @@
 import hashlib
+from typing import Any
 
 
 def compute_sha256(content: str) -> str:
     return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+
+def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
+    """Sanitize metadata to ensure ChromaDB compatibility.
+
+    ChromaDB only accepts str, int, float, or bool values in metadata.
+    This function converts other types to strings.
+
+    Args:
+        metadata: Dictionary of metadata to sanitize
+
+    Returns:
+        Sanitized metadata dictionary with only ChromaDB-compatible types
+    """
+    sanitized = {}
+    for key, value in metadata.items():
+        if isinstance(value, (str, int, float, bool)) or value is None:
+            sanitized[key] = value
+        elif isinstance(value, (list, tuple)):
+            # Convert lists/tuples to pipe-separated strings
+            sanitized[key] = " | ".join(str(v) for v in value)
+        else:
+            # Convert other types to string
+            sanitized[key] = str(value)
+    return sanitized
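The conversion rules above in action (values here are made up):

    from pathlib import Path
    from crewai_tools.rag.misc import sanitize_metadata_for_chromadb

    sanitize_metadata_for_chromadb({"tags": ["rag", "chroma"], "page": 3, "path": Path("a.md")})
    # -> {"tags": "rag | chroma", "page": 3, "path": "a.md"}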
@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedCodeDocsSearchToolSchema(BaseModel):
@@ -42,15 +38,15 @@ class CodeDocsSearchTool(RagTool):
         self._generate_description()
 
     def add(self, docs_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(docs_url, data_type=DataType.DOCS_SITE)
 
     def _run(
         self,
         search_query: str,
         docs_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if docs_url is not None:
             self.add(docs_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,10 @@
 from typing import Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedCSVSearchToolSchema(BaseModel):
@@ -42,15 +38,16 @@ class CSVSearchTool(RagTool):
         self._generate_description()
 
     def add(self, csv: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(csv, data_type=DataType.CSV)
 
     def _run(
         self,
         search_query: str,
         csv: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if csv is not None:
             self.add(csv)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
 
@@ -1,14 +1,9 @@
 from typing import Optional, Type
 
-try:
-    from embedchain.loaders.directory_loader import DirectoryLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedDirectorySearchToolSchema(BaseModel):
@@ -34,8 +29,6 @@ class DirectorySearchTool(RagTool):
     args_schema: Type[BaseModel] = DirectorySearchToolSchema
 
     def __init__(self, directory: Optional[str] = None, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
         if directory is not None:
             self.add(directory)
@@ -44,16 +37,15 @@ class DirectorySearchTool(RagTool):
         self._generate_description()
 
     def add(self, directory: str) -> None:
-        super().add(
-            directory,
-            loader=DirectoryLoader(config=dict(recursive=True)),
-        )
+        super().add(directory, data_type=DataType.DIRECTORY)
 
     def _run(
         self,
         search_query: str,
         directory: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if directory is not None:
             self.add(directory)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedDOCXSearchToolSchema(BaseModel):
@@ -48,15 +44,15 @@ class DOCXSearchTool(RagTool):
         self._generate_description()
 
     def add(self, docx: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(docx, data_type=DataType.DOCX)
 
     def _run(
         self,
         search_query: str,
         docx: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> Any:
         if docx is not None:
             self.add(docx)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,9 @@
-from typing import List, Optional, Type, Any
+from typing import List, Optional, Type
 
-try:
-    from embedchain.loaders.github import GithubLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
-from pydantic import BaseModel, Field, PrivateAttr
+from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedGithubSearchToolSchema(BaseModel):
@@ -42,7 +37,6 @@ class GithubSearchTool(RagTool):
|
|||||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||||
)
|
)
|
||||||
_loader: Any | None = PrivateAttr(default=None)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -50,10 +44,7 @@ class GithubSearchTool(RagTool):
|
|||||||
content_types: Optional[List[str]] = None,
|
content_types: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not EMBEDCHAIN_AVAILABLE:
|
|
||||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
|
||||||
|
|
||||||
if github_repo and content_types:
|
if github_repo and content_types:
|
||||||
self.add(repo=github_repo, content_types=content_types)
|
self.add(repo=github_repo, content_types=content_types)
|
||||||
@@ -67,11 +58,10 @@ class GithubSearchTool(RagTool):
|
|||||||
content_types: Optional[List[str]] = None,
|
content_types: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
content_types = content_types or self.content_types
|
content_types = content_types or self.content_types
|
||||||
|
|
||||||
super().add(
|
super().add(
|
||||||
f"repo:{repo} type:{','.join(content_types)}",
|
f"https://github.com/{repo}",
|
||||||
data_type="github",
|
data_type=DataType.GITHUB,
|
||||||
loader=self._loader,
|
metadata={"content_types": content_types, "gh_token": self.gh_token}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
@@ -79,10 +69,12 @@ class GithubSearchTool(RagTool):
|
|||||||
search_query: str,
|
search_query: str,
|
||||||
github_repo: Optional[str] = None,
|
github_repo: Optional[str] = None,
|
||||||
content_types: Optional[List[str]] = None,
|
content_types: Optional[List[str]] = None,
|
||||||
|
similarity_threshold: float | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if github_repo:
|
if github_repo:
|
||||||
self.add(
|
self.add(
|
||||||
repo=github_repo,
|
repo=github_repo,
|
||||||
content_types=content_types,
|
content_types=content_types,
|
||||||
)
|
)
|
||||||
return super()._run(query=search_query)
|
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||||
|
|||||||
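With the embedchain loader gone, the repo is indexed from its plain GitHub URL and the content types travel as metadata. A usage sketch with a placeholder token and repo, illustrative only:

from crewai_tools import GithubSearchTool

tool = GithubSearchTool(
    gh_token="<GITHUB_TOKEN>",            # placeholder; read from the repo's metadata at load time
    github_repo="crewAIInc/crewAI",       # owner/name, expanded to https://github.com/...
    content_types=["code", "issue"],
)
print(tool._run(search_query="how is the rag adapter wired up?", limit=3))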
@@ -41,7 +41,9 @@ class JSONSearchTool(RagTool):
         self,
         search_query: str,
         json_path: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if json_path is not None:
             self.add(json_path)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -2,13 +2,9 @@ from typing import Optional, Type
 
 from pydantic import BaseModel, Field
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedMDXSearchToolSchema(BaseModel):
@@ -42,15 +38,15 @@ class MDXSearchTool(RagTool):
         self._generate_description()
 
     def add(self, mdx: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(mdx, data_type=DataType.MDX)
 
     def _run(
         self,
         search_query: str,
         mdx: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if mdx is not None:
             self.add(mdx)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,9 @@
 from typing import Any, Type
 
-try:
-    from embedchain.loaders.mysql import MySQLLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class MySQLSearchToolSchema(BaseModel):
@@ -27,12 +22,8 @@ class MySQLSearchTool(RagTool):
     db_uri: str = Field(..., description="Mandatory database URI")
 
     def __init__(self, table_name: str, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
-        kwargs["data_type"] = "mysql"
-        kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
-        self.add(table_name)
+        self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
         self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
         self._generate_description()
 
@@ -46,6 +37,8 @@ class MySQLSearchTool(RagTool):
     def _run(
         self,
         search_query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
         **kwargs: Any,
     ) -> Any:
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
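The MySQL tool now routes the connection through metadata instead of an embedchain loader, so the table is ingested by the new mysql loader at add time. A sketch with placeholder URI and table name, illustrative only:

from crewai_tools import MySQLSearchTool

tool = MySQLSearchTool(
    table_name="employees",                               # placeholder table
    db_uri="mysql://user:password@localhost:3306/mydb",   # placeholder URI, passed via metadata
)
print(tool._run(search_query="who joined the engineering team recently?"))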
@@ -2,13 +2,8 @@ from typing import Optional, Type
 
 from pydantic import BaseModel, Field
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedPDFSearchToolSchema(BaseModel):
@@ -41,15 +36,15 @@ class PDFSearchTool(RagTool):
         self._generate_description()
 
     def add(self, pdf: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(pdf, data_type=DataType.PDF_FILE)
 
     def _run(
         self,
         query: str,
         pdf: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if pdf is not None:
             self.add(pdf)
-        return super()._run(query=query)
+        return super()._run(query=query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,9 @@
 from typing import Any, Type
 
-try:
-    from embedchain.loaders.postgres import PostgresLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class PGSearchToolSchema(BaseModel):
@@ -27,12 +22,8 @@ class PGSearchTool(RagTool):
     db_uri: str = Field(..., description="Mandatory database URI")
 
     def __init__(self, table_name: str, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
-        kwargs["data_type"] = "postgres"
-        kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
-        self.add(table_name)
+        self.add(table_name, data_type=DataType.POSTGRES, metadata={"db_uri": self.db_uri})
         self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
         self._generate_description()
 
@@ -46,6 +37,8 @@ class PGSearchTool(RagTool):
     def _run(
         self,
         search_query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
         **kwargs: Any,
     ) -> Any:
-        return super()._run(query=search_query, **kwargs)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit, **kwargs)
@@ -1,17 +1,22 @@
-import portalocker
+import os
 
 from abc import ABC, abstractmethod
-from typing import Any
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing import Any, cast
 
+from crewai.rag.embeddings.factory import get_embedding_function
 from crewai.tools import BaseTool
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 
 class Adapter(BaseModel, ABC):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     @abstractmethod
-    def query(self, question: str) -> str:
+    def query(
+        self,
+        question: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
+    ) -> str:
         """Query the knowledge base with a question and return the answer."""
 
     @abstractmethod
@@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC):
 
 class RagTool(BaseTool):
     class _AdapterPlaceholder(Adapter):
-        def query(self, question: str) -> str:
+        def query(
+            self,
+            question: str,
+            similarity_threshold: float | None = None,
+            limit: int | None = None,
+        ) -> str:
             raise NotImplementedError
 
         def add(self, *args: Any, **kwargs: Any) -> None:
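For third-party adapters, the widened Adapter.query signature above is the contract to implement. A minimal in-memory sketch, not from this diff; the add signature here just mirrors the *args/**kwargs shape used by _AdapterPlaceholder and is an assumption:

from pydantic import Field
from crewai_tools.tools.rag.rag_tool import Adapter

class ListAdapter(Adapter):
    docs: list[str] = Field(default_factory=list)

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        # No real scoring here; just honor the limit the way a vector store would.
        return "\n".join(self.docs[: limit or 5])

    def add(self, *args, **kwargs) -> None:
        self.docs.extend(str(a) for a in args)

Such an adapter can be injected the same way the test suite below injects its MagicMock, e.g. SomeSearchTool(adapter=ListAdapter()).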
@@ -34,28 +44,149 @@ class RagTool(BaseTool):
     name: str = "Knowledge base"
     description: str = "A knowledge base that can be used to answer questions."
     summarize: bool = False
+    similarity_threshold: float = 0.6
+    limit: int = 5
     adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
-    config: dict[str, Any] | None = None
+    config: Any | None = None
 
     @model_validator(mode="after")
     def _set_default_adapter(self):
         if isinstance(self.adapter, RagTool._AdapterPlaceholder):
-            try:
-                from embedchain import App
-            except ImportError:
-                raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
-
-            from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
-
-            with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
-                app = App.from_config(config=self.config) if self.config else App()
-
-                self.adapter = EmbedchainAdapter(
-                    embedchain_app=app, summarize=self.summarize
-                )
+            from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
+
+            parsed_config = self._parse_config(self.config)
+
+            self.adapter = CrewAIRagAdapter(
+                collection_name="rag_tool_collection",
+                summarize=self.summarize,
+                similarity_threshold=self.similarity_threshold,
+                limit=self.limit,
+                config=parsed_config,
+            )
 
         return self
 
+    def _parse_config(self, config: Any) -> Any:
+        """Parse complex config format to extract provider-specific config.
+
+        Raises:
+            ValueError: If the config format is invalid or uses unsupported providers.
+        """
+        if config is None:
+            return None
+
+        if isinstance(config, dict) and "provider" in config:
+            return config
+
+        if isinstance(config, dict):
+            if "vectordb" in config:
+                vectordb_config = config["vectordb"]
+                if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
+                    provider = vectordb_config["provider"]
+                    provider_config = vectordb_config.get("config", {})
+
+                    supported_providers = ["chromadb", "qdrant"]
+                    if provider not in supported_providers:
+                        raise ValueError(
+                            f"Unsupported vector database provider: '{provider}'. "
+                            f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
+                        )
+
+                    embedding_config = config.get("embedding_model")
+                    embedding_function = None
+                    if embedding_config and isinstance(embedding_config, dict):
+                        embedding_function = self._create_embedding_function(
+                            embedding_config, provider
+                        )
+
+                    return self._create_provider_config(
+                        provider, provider_config, embedding_function
+                    )
+                else:
+                    return None
+            else:
+                embedding_config = config.get("embedding_model")
+                embedding_function = None
+                if embedding_config and isinstance(embedding_config, dict):
+                    embedding_function = self._create_embedding_function(
+                        embedding_config, "chromadb"
+                    )
+
+                return self._create_provider_config("chromadb", {}, embedding_function)
+        return config
+
+    @staticmethod
+    def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
+        """Create embedding function for the specified vector database provider."""
+        embedding_provider = embedding_config.get("provider")
+        embedding_model_config = embedding_config.get("config", {}).copy()
+
+        if "model" in embedding_model_config:
+            embedding_model_config["model_name"] = embedding_model_config.pop("model")
+
+        factory_config = {"provider": embedding_provider, **embedding_model_config}
+
+        if embedding_provider == "openai" and "api_key" not in factory_config:
+            api_key = os.getenv("OPENAI_API_KEY")
+            if api_key:
+                factory_config["api_key"] = api_key
+
+        print(f"Creating embedding function with config: {factory_config}")
+
+        if provider == "chromadb":
+            embedding_func = get_embedding_function(factory_config)
+            print(f"Created embedding function: {embedding_func}")
+            print(f"Embedding function type: {type(embedding_func)}")
+            return embedding_func
+
+        elif provider == "qdrant":
+            chromadb_func = get_embedding_function(factory_config)
+
+            def qdrant_embed_fn(text: str) -> list[float]:
+                """Embed text using ChromaDB function and convert to list of floats for Qdrant.
+
+                Args:
+                    text: The input text to embed.
+
+                Returns:
+                    A list of floats representing the embedding.
+                """
+                embeddings = chromadb_func([text])
+                return embeddings[0] if embeddings and len(embeddings) > 0 else []
+
+            return cast(Any, qdrant_embed_fn)
+
+        return None
+
+    @staticmethod
+    def _create_provider_config(
+        provider: str, provider_config: dict, embedding_function: Any
+    ) -> Any:
+        """Create proper provider config object."""
+        if provider == "chromadb":
+            from crewai.rag.chromadb.config import ChromaDBConfig
+
+            config_kwargs = {}
+            if embedding_function:
+                config_kwargs["embedding_function"] = embedding_function
+
+            config_kwargs.update(provider_config)
+
+            return ChromaDBConfig(**config_kwargs)
+
+        elif provider == "qdrant":
+            from crewai.rag.qdrant.config import QdrantConfig
+
+            config_kwargs = {}
+            if embedding_function:
+                config_kwargs["embedding_function"] = embedding_function
+
+            config_kwargs.update(provider_config)
+
+            return QdrantConfig(**config_kwargs)
+
+        return None
+
     def add(
         self,
         *args: Any,
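For reference, these are the config shapes the _parse_config branches above accept. Values are placeholders; a dict with a top-level "provider" key is passed through untouched, and the sketch assumes crewai is installed and an OpenAI key is available:

from crewai_tools.tools.rag.rag_tool import RagTool

class MyTool(RagTool):
    pass

tool = MyTool(
    config={
        "vectordb": {
            "provider": "chromadb",   # must be "chromadb" or "qdrant", else ValueError
            "config": {},             # forwarded into ChromaDBConfig(**...) / QdrantConfig(**...)
        },
        "embedding_model": {
            "provider": "openai",
            "config": {"model": "text-embedding-3-small"},  # "model" is renamed to "model_name"
        },
    }
)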
@@ -66,5 +197,13 @@ class RagTool(BaseTool):
     def _run(
         self,
         query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
-        return f"Relevant Content:\n{self.adapter.query(query)}"
+        threshold = (
+            similarity_threshold
+            if similarity_threshold is not None
+            else self.similarity_threshold
+        )
+        result_limit = limit if limit is not None else self.limit
+        return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"
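Resolution order in _run: a per-call value wins, otherwise the instance default applies. A small sketch of that precedence, assuming a default vector store and embedding provider are configured:

from crewai_tools.tools.rag.rag_tool import RagTool

class MyTool(RagTool):
    pass

tool = MyTool(similarity_threshold=0.7, limit=10)
tool._run(query="what do we know?")           # resolves to 0.7 / 10
tool._run(query="what do we know?", limit=2)  # keeps 0.7, per-call limit wins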
@@ -39,7 +39,9 @@ class TXTSearchTool(RagTool):
         self,
         search_query: str,
         txt: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if txt is not None:
             self.add(txt)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,9 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedWebsiteSearchToolSchema(BaseModel):
@@ -44,15 +39,15 @@ class WebsiteSearchTool(RagTool):
         self._generate_description()
 
     def add(self, website: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
-        super().add(website, data_type=DataType.WEB_PAGE)
+        super().add(website, data_type=DataType.WEBSITE)
 
     def _run(
         self,
         search_query: str,
         website: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if website is not None:
             self.add(website)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -39,7 +39,9 @@ class XMLSearchTool(RagTool):
         self,
         search_query: str,
         xml: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if xml is not None:
             self.add(xml)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,9 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedYoutubeChannelSearchToolSchema(BaseModel):
@@ -55,7 +50,9 @@ class YoutubeChannelSearchTool(RagTool):
         self,
         search_query: str,
         youtube_channel_handle: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if youtube_channel_handle is not None:
             self.add(youtube_channel_handle)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedYoutubeVideoSearchToolSchema(BaseModel):
@@ -44,15 +40,15 @@ class YoutubeVideoSearchTool(RagTool):
         self._generate_description()
 
     def add(self, youtube_video_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
 
     def _run(
         self,
         search_query: str,
         youtube_video_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if youtube_video_url is not None:
             self.add(youtube_video_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
@@ -1,43 +1,54 @@
-import os
-from tempfile import NamedTemporaryFile
+from tempfile import TemporaryDirectory
 from typing import cast
-from unittest import mock
-
-from pytest import fixture
+from pathlib import Path
 
-from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
+from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
 from crewai_tools.tools.rag.rag_tool import RagTool
 
 
-@fixture(autouse=True)
-def mock_embedchain_db_uri():
-    with NamedTemporaryFile() as tmp:
-        uri = f"sqlite:///{tmp.name}"
-        with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
-            yield
-
-
-def test_custom_llm_and_embedder():
+def test_rag_tool_initialization():
+    """Test that RagTool initializes with CrewAI adapter by default."""
     class MyTool(RagTool):
         pass
 
-    tool = MyTool(
-        config=dict(
-            llm=dict(
-                provider="openai",
-                config=dict(model="gpt-3.5-custom"),
-            ),
-            embedder=dict(
-                provider="openai",
-                config=dict(model="text-embedding-3-custom"),
-            ),
-        )
-    )
+    tool = MyTool()
     assert tool.adapter is not None
-    assert isinstance(tool.adapter, EmbedchainAdapter)
+    assert isinstance(tool.adapter, CrewAIRagAdapter)
 
-    adapter = cast(EmbedchainAdapter, tool.adapter)
-    assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
-    assert (
-        adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
-    )
+    adapter = cast(CrewAIRagAdapter, tool.adapter)
+    assert adapter.collection_name == "rag_tool_collection"
+    assert adapter._client is not None
+
+
+def test_rag_tool_add_and_query():
+    """Test adding content and querying with RagTool."""
+    class MyTool(RagTool):
+        pass
+
+    tool = MyTool()
+
+    tool.add("The sky is blue on a clear day.")
+    tool.add("Machine learning is a subset of artificial intelligence.")
+
+    result = tool._run(query="What color is the sky?")
+    assert "Relevant Content:" in result
+
+    result = tool._run(query="Tell me about machine learning")
+    assert "Relevant Content:" in result
+
+
+def test_rag_tool_with_file():
+    """Test RagTool with file content."""
+    with TemporaryDirectory() as tmpdir:
+        test_file = Path(tmpdir) / "test.txt"
+        test_file.write_text("Python is a programming language known for its simplicity.")
+
+        class MyTool(RagTool):
+            pass
+
+        tool = MyTool()
+        tool.add(str(test_file))
+
+        result = tool._run(query="What is Python?")
+        assert "Relevant Content:" in result
@@ -1,11 +1,11 @@
 import os
 import tempfile
 from pathlib import Path
-from unittest.mock import ANY, MagicMock
+from unittest.mock import MagicMock
 
 import pytest
-from embedchain.models.data_type import DataType
 
+from crewai_tools.rag.data_types import DataType
 from crewai_tools.tools import (
     CodeDocsSearchTool,
     CSVSearchTool,
@@ -49,7 +49,7 @@ def test_pdf_search_tool(mock_adapter):
     result = tool._run(query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -58,7 +58,7 @@ def test_pdf_search_tool(mock_adapter):
     result = tool._run(pdf="test.pdf", query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
 
 def test_txt_search_tool():
@@ -82,7 +82,7 @@ def test_docx_search_tool(mock_adapter):
     result = tool._run(search_query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -91,7 +91,7 @@ def test_docx_search_tool(mock_adapter):
     result = tool._run(docx="test.docx", search_query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
 
 def test_json_search_tool():
@@ -114,7 +114,7 @@ def test_xml_search_tool(mock_adapter):
     result = tool._run(search_query="test XML", xml="test.xml")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.xml")
-    mock_adapter.query.assert_called_once_with("test XML")
+    mock_adapter.query.assert_called_once_with("test XML", similarity_threshold=0.6, limit=5)
 
 
 def test_csv_search_tool():
@@ -153,8 +153,8 @@ def test_website_search_tool(mock_adapter):
     tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
     result = tool._run(search_query=search_query)
 
-    mock_adapter.query.assert_called_once_with("what is crewai?")
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
 
     assert "this is a test" in result.lower()
 
@@ -164,8 +164,8 @@ def test_website_search_tool(mock_adapter):
     tool = WebsiteSearchTool(adapter=mock_adapter)
     result = tool._run(website=website, search_query=search_query)
 
-    mock_adapter.query.assert_called_once_with("what is crewai?")
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
 
     assert "this is a test" in result.lower()
 
@@ -185,7 +185,7 @@ def test_youtube_video_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -197,7 +197,7 @@ def test_youtube_video_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_youtube_channel_search_tool(mock_adapter):
@@ -213,7 +213,7 @@ def test_youtube_channel_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -227,7 +227,7 @@ def test_youtube_channel_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_code_docs_search_tool(mock_adapter):
@@ -239,7 +239,7 @@ def test_code_docs_search_tool(mock_adapter):
     result = tool._run(search_query=search_query)
     assert "test documentation" in result
     mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -248,7 +248,7 @@ def test_code_docs_search_tool(mock_adapter):
     result = tool._run(docs_url=docs_url, search_query=search_query)
     assert "test documentation" in result
     mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_github_search_tool(mock_adapter):
@@ -264,9 +264,11 @@ def test_github_search_tool(mock_adapter):
     result = tool._run(search_query="tell me about crewai repo")
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code"], "gh_token": "test_token"}
     )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure content types provided by run call is used
     mock_adapter.query.reset_mock()
@@ -280,9 +282,11 @@ def test_github_search_tool(mock_adapter):
     )
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "issue"], "gh_token": "test_token"}
    )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure default content types are used if not provided
     mock_adapter.query.reset_mock()
@@ -295,9 +299,11 @@ def test_github_search_tool(mock_adapter):
     )
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "repo", "pr", "issue"], "gh_token": "test_token"}
     )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure nothing is added if no repo is provided
    mock_adapter.query.reset_mock()
@@ -306,4 +312,4 @@ def test_github_search_tool(mock_adapter):
     tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
     result = tool._run(search_query="tell me about crewai repo")
     mock_adapter.add.assert_not_called()
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)