feat: replace embedchain with native crewai adapter (#451)

- Remove embedchain adapter; add crewai rag adapter and update all search tools
- Add loaders: PDF, YouTube (video & channel), GitHub, docs site, MySQL, PostgreSQL
- Add configurable similarity threshold, limit params, and embedding_model support
- Improve ChromaDB compatibility (sanitize metadata, convert columns, fix chunking)
- Fix XML encoding, Python 3.10 issues, and YouTube URL spoofing
- Update crewai dependency and instructions; refresh uv.lock
- Update tests for new rag adapter and search params
Greyson LaLonde
2025-09-18 19:02:22 -04:00
committed by GitHub
parent 8d9cee45f2
commit e29ca9ec28
33 changed files with 1317 additions and 277 deletions
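
For reviewers: a minimal, hypothetical usage sketch of the new per-call search parameters added across the search tools in this PR (the file path and query are placeholders; assumes a default ChromaDB-backed RAG setup):

from crewai_tools import TXTSearchTool

# Index a local text file, then search it with the new knobs.
tool = TXTSearchTool(txt="notes.txt")  # placeholder path
print(tool.run(
    search_query="deployment checklist",
    similarity_threshold=0.7,  # per-call override of the 0.6 default
    limit=3,                   # per-call override of the 5-result default
))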

View File

@@ -0,0 +1,215 @@
"""Adapter for CrewAI's native RAG system."""
from typing import Any, TypedDict, TypeAlias
from typing_extensions import Unpack
from pathlib import Path
import hashlib
from pydantic import Field, PrivateAttr
from crewai.rag.config.utils import get_rag_client
from crewai.rag.config.types import RagConfigType
from crewai.rag.types import BaseRecord, SearchResult
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
ContentItem: TypeAlias = str | Path | dict[str, Any]
class AddDocumentParams(TypedDict, total=False):
"""Parameters for adding documents to the RAG system."""
data_type: DataType
metadata: dict[str, Any]
website: str
url: str
file_path: str | Path
github_url: str
youtube_url: str
directory_path: str | Path
class CrewAIRagAdapter(Adapter):
"""Adapter that uses CrewAI's native RAG system.
Supports custom vector database configuration through the config parameter.
"""
collection_name: str = "default"
summarize: bool = False
similarity_threshold: float = 0.6
limit: int = 5
config: RagConfigType | None = None
_client: BaseClient | None = PrivateAttr(default=None)
def model_post_init(self, __context: Any) -> None:
"""Initialize the CrewAI RAG client after model initialization."""
if self.config is not None:
self._client = create_client(self.config)
else:
self._client = get_rag_client()
self._client.get_or_create_collection(collection_name=self.collection_name)
def query(self, question: str, similarity_threshold: float | None = None, limit: int | None = None) -> str:
"""Query the knowledge base with a question.
Args:
question: The question to ask
similarity_threshold: Minimum similarity score for results (default: 0.6)
limit: Maximum number of results to return (default: 5)
Returns:
Relevant content from the knowledge base
"""
search_limit = limit if limit is not None else self.limit
search_threshold = similarity_threshold if similarity_threshold is not None else self.similarity_threshold
results: list[SearchResult] = self._client.search(
collection_name=self.collection_name,
query=question,
limit=search_limit,
score_threshold=search_threshold
)
if not results:
return "No relevant content found."
contents: list[str] = []
for result in results:
content: str = result.get("content", "")
if content:
contents.append(content)
return "\n\n".join(contents)
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
"""Add content to the knowledge base.
This method handles various input types and converts them to documents
for the vector database. It supports the data_type parameter for
compatibility with existing tools.
Args:
*args: Content items to add (strings, paths, or document dicts)
**kwargs: Additional parameters including data_type, metadata, etc.
"""
from crewai_tools.rag.data_types import DataTypes, DataType
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.base_loader import LoaderResult
import os
documents: list[BaseRecord] = []
data_type: DataType | None = kwargs.get("data_type")
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
for arg in args:
source_ref: str
if isinstance(arg, dict):
source_ref = str(arg.get("source", arg.get("content", "")))
else:
source_ref = str(arg)
if not data_type:
data_type = DataTypes.from_content(source_ref)
if data_type == DataType.DIRECTORY:
if not os.path.isdir(source_ref):
raise ValueError(f"Directory does not exist: {source_ref}")
# Define binary and non-text file extensions to skip
binary_extensions = {'.pyc', '.pyo', '.png', '.jpg', '.jpeg', '.gif',
'.bmp', '.ico', '.svg', '.webp', '.pdf', '.zip',
'.tar', '.gz', '.bz2', '.7z', '.rar', '.exe',
'.dll', '.so', '.dylib', '.bin', '.dat', '.db',
'.sqlite', '.class', '.jar', '.war', '.ear'}
for root, dirs, files in os.walk(source_ref):
dirs[:] = [d for d in dirs if not d.startswith('.')]
for filename in files:
if filename.startswith('.'):
continue
# Skip binary files based on extension
file_ext = os.path.splitext(filename)[1].lower()
if file_ext in binary_extensions:
continue
# Skip __pycache__ directories
if '__pycache__' in root:
continue
file_path: str = os.path.join(root, filename)
try:
file_data_type: DataType = DataTypes.from_content(file_path)
file_loader = file_data_type.get_loader()
file_chunker = file_data_type.get_chunker()
file_source = SourceContent(file_path)
file_result: LoaderResult = file_loader.load(file_source)
file_chunks = file_chunker.chunk(file_result.content)
for chunk_idx, file_chunk in enumerate(file_chunks):
file_metadata: dict[str, Any] = base_metadata.copy()
file_metadata.update(file_result.metadata)
file_metadata["data_type"] = str(file_data_type)
file_metadata["file_path"] = file_path
file_metadata["chunk_index"] = chunk_idx
file_metadata["total_chunks"] = len(file_chunks)
if isinstance(arg, dict):
file_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()).hexdigest()
documents.append({
"doc_id": chunk_id,
"content": file_chunk,
"metadata": sanitize_metadata_for_chromadb(file_metadata)
})
except Exception:
# Silently skip files that can't be processed
continue
else:
metadata: dict[str, Any] = base_metadata.copy()
if data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
if not os.path.isfile(source_ref):
raise FileNotFoundError(f"File does not exist: {source_ref}")
loader = data_type.get_loader()
chunker = data_type.get_chunker()
source_content = SourceContent(source_ref)
loader_result: LoaderResult = loader.load(source_content)
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
chunk_metadata: dict[str, Any] = metadata.copy()
chunk_metadata.update(loader_result.metadata)
chunk_metadata["data_type"] = str(data_type)
chunk_metadata["chunk_index"] = i
chunk_metadata["total_chunks"] = len(chunks)
chunk_metadata["source"] = source_ref
if isinstance(arg, dict):
chunk_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(f"{loader_result.doc_id}_{i}_{chunk}".encode()).hexdigest()
documents.append({
"doc_id": chunk_id,
"content": chunk,
"metadata": sanitize_metadata_for_chromadb(chunk_metadata)
})
if documents:
self._client.add_documents(
collection_name=self.collection_name,
documents=documents
)
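
A sketch of exercising the adapter above directly (hypothetical paths; assumes crewai's default RAG client, ChromaDB, is reachable via get_rag_client()):

from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.rag.data_types import DataType

adapter = CrewAIRagAdapter(collection_name="demo")
# Loads, chunks, and indexes the file; data_type is inferred when omitted.
adapter.add("docs/handbook.txt", data_type=DataType.TEXT_FILE)  # placeholder path
# Per-call overrides of the adapter-level threshold/limit defaults.
print(adapter.query("What is the vacation policy?", similarity_threshold=0.7, limit=3))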

View File

@@ -1,34 +0,0 @@
from typing import Any
from crewai_tools.tools.rag.rag_tool import Adapter
try:
from embedchain import App
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
class EmbedchainAdapter(Adapter):
embedchain_app: Any # Will be App when embedchain is available
summarize: bool = False
def __init__(self, **data):
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().__init__(**data)
def query(self, question: str) -> str:
result, sources = self.embedchain_app.query(
question, citations=True, dry_run=(not self.summarize)
)
if self.summarize:
return result
return "\n\n".join([source[0] for source in sources])
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.embedchain_app.add(*args, **kwargs)

View File

@@ -1,41 +0,0 @@
from typing import Any, Optional
from crewai_tools.tools.rag.rag_tool import Adapter
try:
from embedchain import App
EMBEDCHAIN_AVAILABLE = True
except ImportError:
EMBEDCHAIN_AVAILABLE = False
class PDFEmbedchainAdapter(Adapter):
embedchain_app: Any # Will be App when embedchain is available
summarize: bool = False
src: Optional[str] = None
def __init__(self, **data):
if not EMBEDCHAIN_AVAILABLE:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
super().__init__(**data)
def query(self, question: str) -> str:
where = (
{"app_id": self.embedchain_app.config.id, "source": self.src}
if self.src
else None
)
result, sources = self.embedchain_app.query(
question, citations=True, dry_run=(not self.summarize), where=where
)
if self.summarize:
return result
return "\n\n".join([source[0] for source in sources])
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.src = args[0] if args else None
self.embedchain_app.add(*args, **kwargs)

View File

@@ -112,7 +112,10 @@ class RecursiveCharacterTextSplitter:
         if separator == "":
             doc = "".join(current_doc)
         else:
-            doc = separator.join(current_doc)
+            if self._keep_separator and separator == " ":
+                doc = "".join(current_doc)
+            else:
+                doc = separator.join(current_doc)
         if doc:
             docs.append(doc)

@@ -133,7 +136,10 @@ class RecursiveCharacterTextSplitter:
         if separator == "":
             doc = "".join(current_doc)
         else:
-            doc = separator.join(current_doc)
+            if self._keep_separator and separator == " ":
+                doc = "".join(current_doc)
+            else:
+                doc = separator.join(current_doc)
         if doc:
             docs.append(doc)

View File

@@ -25,6 +25,8 @@ class DataType(str, Enum):
     # Web types
     WEBSITE = "website"
     DOCS_SITE = "docs_site"
+    YOUTUBE_VIDEO = "youtube_video"
+    YOUTUBE_CHANNEL = "youtube_channel"

     # Raw types
     TEXT = "text"

@@ -34,6 +36,7 @@ class DataType(str, Enum):
         from importlib import import_module

         chunkers = {
+            DataType.PDF_FILE: ("text_chunker", "TextChunker"),
             DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
             DataType.TEXT: ("text_chunker", "TextChunker"),
             DataType.DOCX: ("text_chunker", "DocxChunker"),

@@ -45,9 +48,18 @@ class DataType(str, Enum):
             DataType.XML: ("structured_chunker", "XmlChunker"),
             DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
+            DataType.DIRECTORY: ("text_chunker", "TextChunker"),
+            DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
+            DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
+            DataType.GITHUB: ("text_chunker", "TextChunker"),
+            DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
+            DataType.MYSQL: ("text_chunker", "TextChunker"),
+            DataType.POSTGRES: ("text_chunker", "TextChunker"),
         }
-        module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
+        if self not in chunkers:
+            raise ValueError(f"No chunker defined for {self}")
+        module_name, class_name = chunkers[self]
         module_path = f"crewai_tools.rag.chunkers.{module_name}"
         try:

@@ -60,6 +72,7 @@ class DataType(str, Enum):
         from importlib import import_module

         loaders = {
+            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
             DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
             DataType.TEXT: ("text_loader", "TextLoader"),
             DataType.XML: ("xml_loader", "XMLLoader"),

@@ -69,9 +82,17 @@ class DataType(str, Enum):
             DataType.DOCX: ("docx_loader", "DOCXLoader"),
             DataType.CSV: ("csv_loader", "CSVLoader"),
             DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
+            DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
+            DataType.YOUTUBE_CHANNEL: ("youtube_channel_loader", "YoutubeChannelLoader"),
+            DataType.GITHUB: ("github_loader", "GithubLoader"),
+            DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
+            DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
+            DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
         }
-        module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
+        if self not in loaders:
+            raise ValueError(f"No loader defined for {self}")
+        module_name, class_name = loaders[self]
         module_path = f"crewai_tools.rag.loaders.{module_name}"
         try:
             module = import_module(module_path)

View File

@@ -6,6 +6,9 @@ from crewai_tools.rag.loaders.json_loader import JSONLoader
 from crewai_tools.rag.loaders.docx_loader import DOCXLoader
 from crewai_tools.rag.loaders.csv_loader import CSVLoader
 from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
+from crewai_tools.rag.loaders.pdf_loader import PDFLoader
+from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader
+from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader

 __all__ = [
     "TextFileLoader",

@@ -17,4 +20,7 @@ __all__ = [
     "DOCXLoader",
     "CSVLoader",
     "DirectoryLoader",
+    "PDFLoader",
+    "YoutubeVideoLoader",
+    "YoutubeChannelLoader",
 ]

View File

@@ -0,0 +1,98 @@
"""Documentation site loader."""
from typing import Any
from urllib.parse import urljoin, urlparse
import requests
from bs4 import BeautifulSoup
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class DocsSiteLoader(BaseLoader):
"""Loader for documentation websites."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a documentation site.
Args:
source: Documentation site URL
**kwargs: Additional arguments
Returns:
LoaderResult with documentation content
"""
docs_url = source.source
try:
response = requests.get(docs_url, timeout=30)
response.raise_for_status()
except requests.RequestException as e:
raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")
soup = BeautifulSoup(response.text, "html.parser")
for script in soup(["script", "style"]):
script.decompose()
title = soup.find("title")
title_text = title.get_text(strip=True) if title else "Documentation"
main_content = None
for selector in ["main", "article", '[role="main"]', ".content", "#content", ".documentation"]:
main_content = soup.select_one(selector)
if main_content:
break
if not main_content:
main_content = soup.find("body")
if not main_content:
raise ValueError(f"Unable to extract content from documentation site: {docs_url}")
text_parts = [f"Title: {title_text}", ""]
headings = main_content.find_all(["h1", "h2", "h3"])
if headings:
text_parts.append("Table of Contents:")
for heading in headings[:15]:
level = int(heading.name[1])
indent = " " * (level - 1)
text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
text_parts.append("")
text = main_content.get_text(separator="\n", strip=True)
lines = [line.strip() for line in text.split("\n") if line.strip()]
text_parts.extend(lines)
nav_links = []
for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
nav = soup.select_one(nav_selector)
if nav:
links = nav.find_all("a", href=True)
for link in links[:20]:
href = link["href"]
if not href.startswith(("http://", "https://", "mailto:", "#")):
full_url = urljoin(docs_url, href)
nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")
if nav_links:
text_parts.append("")
text_parts.append("Related documentation pages:")
text_parts.extend(nav_links[:10])
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": docs_url,
"title": title_text,
"domain": urlparse(docs_url).netloc
},
doc_id=self.generate_doc_id(source_ref=docs_url, content=content)
)
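
The new loaders all share the load(SourceContent, **kwargs) -> LoaderResult contract, so a sketch for this one carries over to the others (URL is illustrative; assumes the loader is default-constructible, as in DataType.get_loader()):

from crewai_tools.rag.loaders.docs_site_loader import DocsSiteLoader
from crewai_tools.rag.source_content import SourceContent

loader = DocsSiteLoader()
result = loader.load(SourceContent("https://docs.example.com/"))  # placeholder URL
print(result.metadata["title"], len(result.content))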

View File

@@ -0,0 +1,110 @@
"""GitHub repository content loader."""
from typing import Any
from github import Github, GithubException
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class GithubLoader(BaseLoader):
"""Loader for GitHub repository content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a GitHub repository.
Args:
source: GitHub repository URL
**kwargs: Additional arguments including gh_token and content_types
Returns:
LoaderResult with repository content
"""
metadata = kwargs.get("metadata", {})
gh_token = metadata.get("gh_token")
content_types = metadata.get("content_types", ["code", "repo"])
repo_url = source.source
if not repo_url.startswith("https://github.com/"):
raise ValueError(f"Invalid GitHub URL: {repo_url}")
parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
if len(parts) < 2:
raise ValueError(f"Invalid GitHub repository URL: {repo_url}")
repo_name = f"{parts[0]}/{parts[1]}"
g = Github(gh_token) if gh_token else Github()
try:
repo = g.get_repo(repo_name)
except GithubException as e:
raise ValueError(f"Unable to access repository {repo_name}: {e}")
all_content = []
if "repo" in content_types:
all_content.append(f"Repository: {repo.full_name}")
all_content.append(f"Description: {repo.description or 'No description'}")
all_content.append(f"Language: {repo.language or 'Not specified'}")
all_content.append(f"Stars: {repo.stargazers_count}")
all_content.append(f"Forks: {repo.forks_count}")
all_content.append("")
if "code" in content_types:
try:
readme = repo.get_readme()
all_content.append("README:")
all_content.append(readme.decoded_content.decode("utf-8", errors="ignore"))
all_content.append("")
except GithubException:
pass
try:
contents = repo.get_contents("")
if isinstance(contents, list):
all_content.append("Repository structure:")
for content_file in contents[:20]:
all_content.append(f"- {content_file.path} ({content_file.type})")
all_content.append("")
except GithubException:
pass
if "pr" in content_types:
prs = repo.get_pulls(state="open")
pr_list = list(prs[:5])
if pr_list:
all_content.append("Recent Pull Requests:")
for pr in pr_list:
all_content.append(f"- PR #{pr.number}: {pr.title}")
if pr.body:
body_preview = pr.body[:200].replace("\n", " ")
all_content.append(f" {body_preview}")
all_content.append("")
if "issue" in content_types:
issues = repo.get_issues(state="open")
issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
if issue_list:
all_content.append("Recent Issues:")
for issue in issue_list:
all_content.append(f"- Issue #{issue.number}: {issue.title}")
if issue.body:
body_preview = issue.body[:200].replace("\n", " ")
all_content.append(f" {body_preview}")
all_content.append("")
if not all_content:
raise ValueError(f"No content could be loaded from repository: {repo_url}")
content = "\n".join(all_content)
return LoaderResult(
content=content,
metadata={
"source": repo_url,
"repo": repo_name,
"content_types": content_types
},
doc_id=self.generate_doc_id(source_ref=repo_url, content=content)
)
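
Per the code above, gh_token and content_types travel in the metadata kwarg rather than as loader config; a hedged sketch (repo and token are placeholders, and the token is optional):

from crewai_tools.rag.loaders.github_loader import GithubLoader
from crewai_tools.rag.source_content import SourceContent

loader = GithubLoader()
result = loader.load(
    SourceContent("https://github.com/octocat/hello-world"),  # placeholder repo
    metadata={"gh_token": "ghp_placeholder", "content_types": ["repo", "code"]},
)
print(result.metadata["repo"])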

View File

@@ -0,0 +1,99 @@
"""MySQL database loader."""
from typing import Any
from urllib.parse import urlparse
import pymysql
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class MySQLLoader(BaseLoader):
"""Loader for MySQL database content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a MySQL database table.
Args:
source: SQL query (e.g., "SELECT * FROM table_name")
**kwargs: Additional arguments including db_uri
Returns:
LoaderResult with database content
"""
metadata = kwargs.get("metadata", {})
db_uri = metadata.get("db_uri")
if not db_uri:
raise ValueError("Database URI is required for MySQL loader")
query = source.source
parsed = urlparse(db_uri)
if parsed.scheme not in ["mysql", "mysql+pymysql"]:
raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")
connection_params = {
"host": parsed.hostname or "localhost",
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path.lstrip("/") if parsed.path else None,
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor
}
if not connection_params["database"]:
raise ValueError("Database name is required in the URI")
try:
connection = pymysql.connect(**connection_params)
try:
with connection.cursor() as cursor:
cursor.execute(query)
rows = cursor.fetchall()
if not rows:
content = "No data found in the table"
return LoaderResult(
content=content,
metadata={"source": query, "row_count": 0},
doc_id=self.generate_doc_id(source_ref=query, content=content)
)
text_parts = []
columns = list(rows[0].keys())
text_parts.append(f"Columns: {', '.join(columns)}")
text_parts.append(f"Total rows: {len(rows)}")
text_parts.append("")
for i, row in enumerate(rows, 1):
text_parts.append(f"Row {i}:")
for col, val in row.items():
if val is not None:
text_parts.append(f" {col}: {val}")
text_parts.append("")
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": query,
"database": connection_params["database"],
"row_count": len(rows),
"columns": columns
},
doc_id=self.generate_doc_id(source_ref=query, content=content)
)
finally:
connection.close()
except pymysql.Error as e:
raise ValueError(f"MySQL database error: {e}")
except Exception as e:
raise ValueError(f"Failed to load data from MySQL: {e}")

View File

@@ -0,0 +1,72 @@
"""PDF loader for extracting text from PDF files."""
import os
from pathlib import Path
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class PDFLoader(BaseLoader):
"""Loader for PDF files."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract text from a PDF file.
Args:
source: The source content containing the PDF file path
Returns:
LoaderResult with extracted text content
Raises:
FileNotFoundError: If the PDF file doesn't exist
ImportError: If required PDF libraries aren't installed
"""
try:
import pypdf
except ImportError:
try:
import PyPDF2 as pypdf
except ImportError:
raise ImportError(
"PDF support requires pypdf or PyPDF2. "
"Install with: uv add pypdf"
)
file_path = source.source
if not os.path.isfile(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")
text_content = []
metadata: dict[str, Any] = {
"source": str(file_path),
"file_name": Path(file_path).name,
"file_type": "pdf"
}
try:
with open(file_path, 'rb') as file:
pdf_reader = pypdf.PdfReader(file)
metadata["num_pages"] = len(pdf_reader.pages)
for page_num, page in enumerate(pdf_reader.pages, 1):
page_text = page.extract_text()
if page_text.strip():
text_content.append(f"Page {page_num}:\n{page_text}")
except Exception as e:
raise ValueError(f"Error reading PDF file {file_path}: {str(e)}")
if not text_content:
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
else:
content = "\n\n".join(text_content)
return LoaderResult(
content=content,
source=str(file_path),
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content)
)

View File

@@ -0,0 +1,99 @@
"""PostgreSQL database loader."""
from typing import Any
from urllib.parse import urlparse
import psycopg2
from psycopg2.extras import RealDictCursor
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class PostgresLoader(BaseLoader):
"""Loader for PostgreSQL database content."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load content from a PostgreSQL database table.
Args:
source: SQL query (e.g., "SELECT * FROM table_name")
**kwargs: Additional arguments including db_uri
Returns:
LoaderResult with database content
"""
metadata = kwargs.get("metadata", {})
db_uri = metadata.get("db_uri")
if not db_uri:
raise ValueError("Database URI is required for PostgreSQL loader")
query = source.source
parsed = urlparse(db_uri)
if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")
connection_params = {
"host": parsed.hostname or "localhost",
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path.lstrip("/") if parsed.path else None,
"cursor_factory": RealDictCursor
}
if not connection_params["database"]:
raise ValueError("Database name is required in the URI")
try:
connection = psycopg2.connect(**connection_params)
try:
with connection.cursor() as cursor:
cursor.execute(query)
rows = cursor.fetchall()
if not rows:
content = "No data found in the table"
return LoaderResult(
content=content,
metadata={"source": query, "row_count": 0},
doc_id=self.generate_doc_id(source_ref=query, content=content)
)
text_parts = []
columns = list(rows[0].keys())
text_parts.append(f"Columns: {', '.join(columns)}")
text_parts.append(f"Total rows: {len(rows)}")
text_parts.append("")
for i, row in enumerate(rows, 1):
text_parts.append(f"Row {i}:")
for col, val in row.items():
if val is not None:
text_parts.append(f" {col}: {val}")
text_parts.append("")
content = "\n".join(text_parts)
if len(content) > 100000:
content = content[:100000] + "\n\n[Content truncated...]"
return LoaderResult(
content=content,
metadata={
"source": query,
"database": connection_params["database"],
"row_count": len(rows),
"columns": columns
},
doc_id=self.generate_doc_id(source_ref=query, content=content)
)
finally:
connection.close()
except psycopg2.Error as e:
raise ValueError(f"PostgreSQL database error: {e}")
except Exception as e:
raise ValueError(f"Failed to load data from PostgreSQL: {e}")

View File

@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):
         if source_content.is_url():
             content = self._load_from_url(source_ref, kwargs)
-        elif os.path.exists(source_ref):
+        elif source_content.path_exists():
             content = self._load_from_file(source_ref)

         return self._parse_xml(content, source_ref)

View File

@@ -0,0 +1,141 @@
"""YouTube channel loader for extracting content from YouTube channels."""
import re
from typing import Any
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class YoutubeChannelLoader(BaseLoader):
"""Loader for YouTube channels."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract content from a YouTube channel.
Args:
source: The source content containing the YouTube channel URL
Returns:
LoaderResult with channel content
Raises:
ImportError: If required YouTube libraries aren't installed
ValueError: If the URL is not a valid YouTube channel URL
"""
try:
from pytube import Channel
except ImportError:
raise ImportError(
"YouTube channel support requires pytube. "
"Install with: uv add pytube"
)
channel_url = source.source
if not any(pattern in channel_url for pattern in ['youtube.com/channel/', 'youtube.com/c/', 'youtube.com/@', 'youtube.com/user/']):
raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
metadata: dict[str, Any] = {
"source": channel_url,
"data_type": "youtube_channel"
}
try:
channel = Channel(channel_url)
metadata["channel_name"] = channel.channel_name
metadata["channel_id"] = channel.channel_id
max_videos = kwargs.get('max_videos', 10)
video_urls = list(channel.video_urls)[:max_videos]
metadata["num_videos_loaded"] = len(video_urls)
metadata["total_videos"] = len(list(channel.video_urls))
content_parts = [
f"YouTube Channel: {channel.channel_name}",
f"Channel ID: {channel.channel_id}",
f"Total Videos: {metadata['total_videos']}",
f"Videos Loaded: {metadata['num_videos_loaded']}",
"\n--- Video Summaries ---\n"
]
try:
from youtube_transcript_api import YouTubeTranscriptApi
from pytube import YouTube
for i, video_url in enumerate(video_urls, 1):
try:
video_id = self._extract_video_id(video_url)
if not video_id:
continue
yt = YouTube(video_url)
title = yt.title or f"Video {i}"
description = yt.description[:200] if yt.description else "No description"
content_parts.append(f"\n{i}. {title}")
content_parts.append(f" URL: {video_url}")
content_parts.append(f" Description: {description}...")
try:
api = YouTubeTranscriptApi()
transcript_list = api.list(video_id)
transcript = None
try:
transcript = transcript_list.find_transcript(['en'])
except:
try:
transcript = transcript_list.find_generated_transcript(['en'])
except:
transcript = next(iter(transcript_list), None)
if transcript:
transcript_data = transcript.fetch()
text_parts = []
char_count = 0
for entry in transcript_data:
text = entry.text.strip() if hasattr(entry, 'text') else ''
if text:
text_parts.append(text)
char_count += len(text)
if char_count > 500:
break
if text_parts:
preview = ' '.join(text_parts)[:500]
content_parts.append(f" Transcript Preview: {preview}...")
except:
content_parts.append(" Transcript: Not available")
except Exception as e:
content_parts.append(f"\n{i}. Error loading video: {str(e)}")
except ImportError:
for i, video_url in enumerate(video_urls, 1):
content_parts.append(f"\n{i}. {video_url}")
content = '\n'.join(content_parts)
except Exception as e:
raise ValueError(f"Unable to load YouTube channel {channel_url}: {str(e)}") from e
return LoaderResult(
content=content,
source=channel_url,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=channel_url, content=content)
)
def _extract_video_id(self, url: str) -> str | None:
"""Extract video ID from YouTube URL."""
patterns = [
r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
return None

View File

@@ -0,0 +1,123 @@
"""YouTube video loader for extracting transcripts from YouTube videos."""
import re
from typing import Any
from urllib.parse import urlparse, parse_qs
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
class YoutubeVideoLoader(BaseLoader):
"""Loader for YouTube videos."""
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
"""Load and extract transcript from a YouTube video.
Args:
source: The source content containing the YouTube URL
Returns:
LoaderResult with transcript content
Raises:
ImportError: If required YouTube libraries aren't installed
ValueError: If the URL is not a valid YouTube video URL
"""
try:
from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
raise ImportError(
"YouTube support requires youtube-transcript-api. "
"Install with: uv add youtube-transcript-api"
)
video_url = source.source
video_id = self._extract_video_id(video_url)
if not video_id:
raise ValueError(f"Invalid YouTube URL: {video_url}")
metadata: dict[str, Any] = {
"source": video_url,
"video_id": video_id,
"data_type": "youtube_video"
}
try:
api = YouTubeTranscriptApi()
transcript_list = api.list(video_id)
transcript = None
try:
transcript = transcript_list.find_transcript(['en'])
except:
try:
transcript = transcript_list.find_generated_transcript(['en'])
except:
transcript = next(iter(transcript_list))
if transcript:
metadata["language"] = transcript.language
metadata["is_generated"] = transcript.is_generated
transcript_data = transcript.fetch()
text_content = []
for entry in transcript_data:
text = entry.text.strip() if hasattr(entry, 'text') else ''
if text:
text_content.append(text)
content = ' '.join(text_content)
try:
from pytube import YouTube
yt = YouTube(video_url)
metadata["title"] = yt.title
metadata["author"] = yt.author
metadata["length_seconds"] = yt.length
metadata["description"] = yt.description[:500] if yt.description else None
if yt.title:
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
except:
pass
else:
raise ValueError(f"No transcript available for YouTube video: {video_id}")
except Exception as e:
raise ValueError(f"Unable to extract transcript from YouTube video {video_id}: {str(e)}") from e
return LoaderResult(
content=content,
source=video_url,
metadata=metadata,
doc_id=self.generate_doc_id(source_ref=video_url, content=content)
)
def _extract_video_id(self, url: str) -> str | None:
"""Extract video ID from various YouTube URL formats."""
patterns = [
r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
try:
parsed = urlparse(url)
hostname = parsed.hostname
if hostname:
hostname_lower = hostname.lower()
# Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
if hostname_lower == 'youtube.com' or hostname_lower.endswith('.youtube.com') or hostname_lower == 'youtu.be':
query_params = parse_qs(parsed.query)
if 'v' in query_params:
return query_params['v'][0]
except:
pass
return None
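
The hostname comparison above is what closes the URL-spoofing hole noted in the commit message: a substring test accepts look-alike hosts, while parsing and comparing the hostname does not. A stdlib-only illustration:

from urllib.parse import urlparse

spoofed = "https://youtube.com.evil.example/watch?v=abc123"
real = "https://www.youtube.com/watch?v=abc123"

print("youtube.com" in spoofed, "youtube.com" in real)  # True True -- the old bug

def is_youtube(url: str) -> bool:
    # Mirrors the check in _extract_video_id: exact host or true subdomain only.
    host = (urlparse(url).hostname or "").lower()
    return host == "youtube.com" or host.endswith(".youtube.com") or host == "youtu.be"

print(is_youtube(spoofed), is_youtube(real))  # False True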

View File

@@ -1,4 +1,29 @@
 import hashlib
+from typing import Any


 def compute_sha256(content: str) -> str:
     return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+
+def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
+    """Sanitize metadata to ensure ChromaDB compatibility.
+
+    ChromaDB only accepts str, int, float, or bool values in metadata.
+    This function converts other types to strings.
+
+    Args:
+        metadata: Dictionary of metadata to sanitize
+
+    Returns:
+        Sanitized metadata dictionary with only ChromaDB-compatible types
+    """
+    sanitized = {}
+    for key, value in metadata.items():
+        if isinstance(value, (str, int, float, bool)) or value is None:
+            sanitized[key] = value
+        elif isinstance(value, (list, tuple)):
+            # Convert lists/tuples to pipe-separated strings
+            sanitized[key] = " | ".join(str(v) for v in value)
+        else:
+            # Convert other types to string
+            sanitized[key] = str(value)
+    return sanitized
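
Illustrative input/output for the sanitizer (values are made up):

from pathlib import Path
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb

raw = {"chunk_index": 3, "columns": ["id", "name"], "path": Path("/tmp/a.txt")}
print(sanitize_metadata_for_chromadb(raw))
# {'chunk_index': 3, 'columns': 'id | name', 'path': '/tmp/a.txt'}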

View File

@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
-
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedCodeDocsSearchToolSchema(BaseModel):

@@ -42,15 +38,15 @@ class CodeDocsSearchTool(RagTool):
         self._generate_description()

     def add(self, docs_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(docs_url, data_type=DataType.DOCS_SITE)

     def _run(
         self,
         search_query: str,
         docs_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if docs_url is not None:
             self.add(docs_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,10 @@
 from typing import Optional, Type
-
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedCSVSearchToolSchema(BaseModel):

@@ -42,15 +38,16 @@ class CSVSearchTool(RagTool):
         self._generate_description()

     def add(self, csv: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(csv, data_type=DataType.CSV)

     def _run(
         self,
         search_query: str,
         csv: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if csv is not None:
             self.add(csv)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
 from typing import Optional, Type
-
-try:
-    from embedchain.loaders.directory_loader import DirectoryLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedDirectorySearchToolSchema(BaseModel):

@@ -34,8 +29,6 @@ class DirectorySearchTool(RagTool):
     args_schema: Type[BaseModel] = DirectorySearchToolSchema

     def __init__(self, directory: Optional[str] = None, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
         if directory is not None:
             self.add(directory)

@@ -44,16 +37,15 @@ class DirectorySearchTool(RagTool):
         self._generate_description()

     def add(self, directory: str) -> None:
-        super().add(
-            directory,
-            loader=DirectoryLoader(config=dict(recursive=True)),
-        )
+        super().add(directory, data_type=DataType.DIRECTORY)

     def _run(
         self,
         search_query: str,
         directory: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if directory is not None:
             self.add(directory)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
-
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedDOCXSearchToolSchema(BaseModel):

@@ -48,15 +44,15 @@ class DOCXSearchTool(RagTool):
         self._generate_description()

     def add(self, docx: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(docx, data_type=DataType.DOCX)

     def _run(
         self,
         search_query: str,
         docx: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> Any:
         if docx is not None:
             self.add(docx)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
-from typing import List, Optional, Type, Any
-
-try:
-    from embedchain.loaders.github import GithubLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
-from pydantic import BaseModel, Field, PrivateAttr
+from typing import List, Optional, Type
+from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedGithubSearchToolSchema(BaseModel):

@@ -42,7 +37,6 @@ class GithubSearchTool(RagTool):
         default_factory=lambda: ["code", "repo", "pr", "issue"],
         description="Content types you want to be included search, options: [code, repo, pr, issue]",
     )
-    _loader: Any | None = PrivateAttr(default=None)

     def __init__(
         self,

@@ -50,10 +44,7 @@ class GithubSearchTool(RagTool):
         content_types: Optional[List[str]] = None,
         **kwargs,
     ):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
-        self._loader = GithubLoader(config={"token": self.gh_token})
         if github_repo and content_types:
             self.add(repo=github_repo, content_types=content_types)

@@ -67,11 +58,10 @@ class GithubSearchTool(RagTool):
         content_types: Optional[List[str]] = None,
     ) -> None:
         content_types = content_types or self.content_types
         super().add(
-            f"repo:{repo} type:{','.join(content_types)}",
-            data_type="github",
-            loader=self._loader,
+            f"https://github.com/{repo}",
+            data_type=DataType.GITHUB,
+            metadata={"content_types": content_types, "gh_token": self.gh_token}
         )

@@ -79,10 +69,12 @@ class GithubSearchTool(RagTool):
         search_query: str,
         github_repo: Optional[str] = None,
         content_types: Optional[List[str]] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if github_repo:
             self.add(
                 repo=github_repo,
                 content_types=content_types,
             )
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -41,7 +41,9 @@ class JSONSearchTool(RagTool):
         self,
         search_query: str,
         json_path: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if json_path is not None:
             self.add(json_path)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -2,13 +2,9 @@ from typing import Optional, Type
 from pydantic import BaseModel, Field
-
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedMDXSearchToolSchema(BaseModel):

@@ -42,15 +38,15 @@ class MDXSearchTool(RagTool):
         self._generate_description()

     def add(self, mdx: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(mdx, data_type=DataType.MDX)

     def _run(
         self,
         search_query: str,
         mdx: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if mdx is not None:
             self.add(mdx)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
 from typing import Any, Type
-
-try:
-    from embedchain.loaders.mysql import MySQLLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class MySQLSearchToolSchema(BaseModel):

@@ -27,12 +22,8 @@ class MySQLSearchTool(RagTool):
     db_uri: str = Field(..., description="Mandatory database URI")

     def __init__(self, table_name: str, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
-        kwargs["data_type"] = "mysql"
-        kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
-        self.add(table_name)
+        self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
         self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
         self._generate_description()

@@ -46,6 +37,8 @@ class MySQLSearchTool(RagTool):
     def _run(
         self,
         search_query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
         **kwargs: Any,
     ) -> Any:
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -2,13 +2,8 @@ from typing import Optional, Type
 from pydantic import BaseModel, Field
-
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class FixedPDFSearchToolSchema(BaseModel):

@@ -41,15 +36,15 @@ class PDFSearchTool(RagTool):
         self._generate_description()

     def add(self, pdf: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(pdf, data_type=DataType.PDF_FILE)

     def _run(
         self,
         query: str,
         pdf: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if pdf is not None:
             self.add(pdf)
-        return super()._run(query=query)
+        return super()._run(query=query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
 from typing import Any, Type
-
-try:
-    from embedchain.loaders.postgres import PostgresLoader
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False

 from pydantic import BaseModel, Field

 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType


 class PGSearchToolSchema(BaseModel):

@@ -27,12 +22,8 @@ class PGSearchTool(RagTool):
     db_uri: str = Field(..., description="Mandatory database URI")

     def __init__(self, table_name: str, **kwargs):
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().__init__(**kwargs)
-        kwargs["data_type"] = "postgres"
-        kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
-        self.add(table_name)
+        self.add(table_name, data_type=DataType.POSTGRES, metadata={"db_uri": self.db_uri})
         self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
         self._generate_description()

@@ -46,6 +37,8 @@ class PGSearchTool(RagTool):
     def _run(
         self,
         search_query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
         **kwargs: Any,
     ) -> Any:
-        return super()._run(query=search_query, **kwargs)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit, **kwargs)

View File

@@ -1,17 +1,22 @@
import portalocker import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, cast
from pydantic import BaseModel, ConfigDict, Field, model_validator
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator
class Adapter(BaseModel, ABC): class Adapter(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod @abstractmethod
def query(self, question: str) -> str: def query(
self,
question: str,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
"""Query the knowledge base with a question and return the answer.""" """Query the knowledge base with a question and return the answer."""
@abstractmethod @abstractmethod
@@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC):
class RagTool(BaseTool): class RagTool(BaseTool):
class _AdapterPlaceholder(Adapter): class _AdapterPlaceholder(Adapter):
def query(self, question: str) -> str: def query(
self,
question: str,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
raise NotImplementedError raise NotImplementedError
def add(self, *args: Any, **kwargs: Any) -> None: def add(self, *args: Any, **kwargs: Any) -> None:
@@ -34,28 +44,149 @@ class RagTool(BaseTool):
name: str = "Knowledge base" name: str = "Knowledge base"
description: str = "A knowledge base that can be used to answer questions." description: str = "A knowledge base that can be used to answer questions."
summarize: bool = False summarize: bool = False
similarity_threshold: float = 0.6
limit: int = 5
adapter: Adapter = Field(default_factory=_AdapterPlaceholder) adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: dict[str, Any] | None = None config: Any | None = None
@model_validator(mode="after") @model_validator(mode="after")
def _set_default_adapter(self): def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder): if isinstance(self.adapter, RagTool._AdapterPlaceholder):
try: from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from embedchain import App
except ImportError:
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter parsed_config = self._parse_config(self.config)
with portalocker.Lock("crewai-rag-tool.lock", timeout=10): self.adapter = CrewAIRagAdapter(
app = App.from_config(config=self.config) if self.config else App() collection_name="rag_tool_collection",
summarize=self.summarize,
self.adapter = EmbedchainAdapter( similarity_threshold=self.similarity_threshold,
embedchain_app=app, summarize=self.summarize limit=self.limit,
config=parsed_config,
) )
return self return self
def _parse_config(self, config: Any) -> Any:
"""Parse complex config format to extract provider-specific config.
Raises:
ValueError: If the config format is invalid or uses unsupported providers.
"""
if config is None:
return None
if isinstance(config, dict) and "provider" in config:
return config
if isinstance(config, dict):
if "vectordb" in config:
vectordb_config = config["vectordb"]
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
provider = vectordb_config["provider"]
provider_config = vectordb_config.get("config", {})
supported_providers = ["chromadb", "qdrant"]
if provider not in supported_providers:
raise ValueError(
f"Unsupported vector database provider: '{provider}'. "
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
)
embedding_config = config.get("embedding_model")
embedding_function = None
if embedding_config and isinstance(embedding_config, dict):
embedding_function = self._create_embedding_function(
embedding_config, provider
)
return self._create_provider_config(
provider, provider_config, embedding_function
)
else:
return None
else:
embedding_config = config.get("embedding_model")
embedding_function = None
if embedding_config and isinstance(embedding_config, dict):
embedding_function = self._create_embedding_function(
embedding_config, "chromadb"
)
return self._create_provider_config("chromadb", {}, embedding_function)
return config
@staticmethod
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
"""Create embedding function for the specified vector database provider."""
embedding_provider = embedding_config.get("provider")
embedding_model_config = embedding_config.get("config", {}).copy()
if "model" in embedding_model_config:
embedding_model_config["model_name"] = embedding_model_config.pop("model")
factory_config = {"provider": embedding_provider, **embedding_model_config}
if embedding_provider == "openai" and "api_key" not in factory_config:
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
factory_config["api_key"] = api_key
print(f"Creating embedding function with config: {factory_config}")
if provider == "chromadb":
embedding_func = get_embedding_function(factory_config)
print(f"Created embedding function: {embedding_func}")
print(f"Embedding function type: {type(embedding_func)}")
return embedding_func
elif provider == "qdrant":
chromadb_func = get_embedding_function(factory_config)
def qdrant_embed_fn(text: str) -> list[float]:
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
Args:
text: The input text to embed.
Returns:
A list of floats representing the embedding.
"""
embeddings = chromadb_func([text])
return embeddings[0] if embeddings and len(embeddings) > 0 else []
return cast(Any, qdrant_embed_fn)
return None
    @staticmethod
    def _create_provider_config(
        provider: str, provider_config: dict, embedding_function: Any
    ) -> Any:
        """Create proper provider config object."""
        if provider == "chromadb":
            from crewai.rag.chromadb.config import ChromaDBConfig

            config_kwargs = {}
            if embedding_function:
                config_kwargs["embedding_function"] = embedding_function
            config_kwargs.update(provider_config)
            return ChromaDBConfig(**config_kwargs)
        elif provider == "qdrant":
            from crewai.rag.qdrant.config import QdrantConfig

            config_kwargs = {}
            if embedding_function:
                config_kwargs["embedding_function"] = embedding_function
            config_kwargs.update(provider_config)
            return QdrantConfig(**config_kwargs)
        return None
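Putting the two helpers together: a config without a "vectordb" key falls through to the ChromaDB default with the embedder attached. A hypothetical sketch (tool is any RagTool instance; the model name is a placeholder):

    parsed = tool._parse_config({
        "embedding_model": {
            "provider": "openai",
            "config": {"model": "text-embedding-3-small"},
        },
    })
    # parsed is a ChromaDBConfig carrying the OpenAI embedding function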
    def add(
        self,
        *args: Any,
@@ -66,5 +197,13 @@ class RagTool(BaseTool):
    def _run(
        self,
        query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
-        return f"Relevant Content:\n{self.adapter.query(query)}"
+        threshold = (
+            similarity_threshold
+            if similarity_threshold is not None
+            else self.similarity_threshold
+        )
+        result_limit = limit if limit is not None else self.limit
+        return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"

View File

@@ -39,7 +39,9 @@ class TXTSearchTool(RagTool):
        self,
        search_query: str,
        txt: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
        if txt is not None:
            self.add(txt)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
from typing import Any, Optional, Type
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType

class FixedWebsiteSearchToolSchema(BaseModel):
@@ -44,15 +39,15 @@ class WebsiteSearchTool(RagTool):
        self._generate_description()
    def add(self, website: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
-        super().add(website, data_type=DataType.WEB_PAGE)
+        super().add(website, data_type=DataType.WEBSITE)
    def _run(
        self,
        search_query: str,
        website: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
        if website is not None:
            self.add(website)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
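End to end, the website tool accepts the same knobs; a sketch (the URL is a placeholder):

    tool = WebsiteSearchTool(website="https://docs.crewai.com")
    tool._run(search_query="what is crewai?", similarity_threshold=0.7, limit=3)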

View File

@@ -39,7 +39,9 @@ class XMLSearchTool(RagTool):
        self,
        search_query: str,
        xml: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
        if xml is not None:
            self.add(xml)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,9 @@
from typing import Any, Optional, Type
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType

class FixedYoutubeChannelSearchToolSchema(BaseModel):
@@ -55,7 +50,9 @@ class YoutubeChannelSearchTool(RagTool):
        self,
        search_query: str,
        youtube_channel_handle: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
        if youtube_channel_handle is not None:
            self.add(youtube_channel_handle)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)

View File

@@ -1,14 +1,10 @@
from typing import Any, Optional, Type
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType

class FixedYoutubeVideoSearchToolSchema(BaseModel):
@@ -44,15 +40,15 @@ class YoutubeVideoSearchTool(RagTool):
        self._generate_description()
    def add(self, youtube_video_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
    def _run(
        self,
        search_query: str,
        youtube_video_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
        if youtube_video_url is not None:
            self.add(youtube_video_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
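All of the search tools above follow the same delegation pattern: an optional source argument triggers add(), then the query is forwarded with the retrieval parameters. For the video tool, for example (URL left as a placeholder):

    tool = YoutubeVideoSearchTool()
    tool._run(
        search_query="summarize the intro",
        youtube_video_url="<video url>",
        limit=2,
    )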

View File

@@ -1,43 +1,54 @@
-import os
-from tempfile import NamedTemporaryFile
+from tempfile import TemporaryDirectory
from typing import cast
-from unittest import mock
+from pathlib import Path
-from pytest import fixture
-from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
+from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.rag.rag_tool import RagTool

-@fixture(autouse=True)
-def mock_embedchain_db_uri():
-    with NamedTemporaryFile() as tmp:
-        uri = f"sqlite:///{tmp.name}"
-        with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
-            yield

-def test_custom_llm_and_embedder():
+def test_rag_tool_initialization():
+    """Test that RagTool initializes with CrewAI adapter by default."""
    class MyTool(RagTool):
        pass
-    tool = MyTool(
-        config=dict(
-            llm=dict(
-                provider="openai",
-                config=dict(model="gpt-3.5-custom"),
-            ),
-            embedder=dict(
-                provider="openai",
-                config=dict(model="text-embedding-3-custom"),
-            ),
-        )
-    )
+    tool = MyTool()
    assert tool.adapter is not None
-    assert isinstance(tool.adapter, EmbedchainAdapter)
-    adapter = cast(EmbedchainAdapter, tool.adapter)
-    assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
-    assert (
-        adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
-    )
+    assert isinstance(tool.adapter, CrewAIRagAdapter)
+    adapter = cast(CrewAIRagAdapter, tool.adapter)
+    assert adapter.collection_name == "rag_tool_collection"
+    assert adapter._client is not None

+def test_rag_tool_add_and_query():
+    """Test adding content and querying with RagTool."""
+    class MyTool(RagTool):
+        pass
+    tool = MyTool()
+    tool.add("The sky is blue on a clear day.")
+    tool.add("Machine learning is a subset of artificial intelligence.")
+    result = tool._run(query="What color is the sky?")
+    assert "Relevant Content:" in result
+    result = tool._run(query="Tell me about machine learning")
+    assert "Relevant Content:" in result

+def test_rag_tool_with_file():
+    """Test RagTool with file content."""
+    with TemporaryDirectory() as tmpdir:
+        test_file = Path(tmpdir) / "test.txt"
+        test_file.write_text("Python is a programming language known for its simplicity.")
+        class MyTool(RagTool):
+            pass
+        tool = MyTool()
+        tool.add(str(test_file))
+        result = tool._run(query="What is Python?")
+        assert "Relevant Content:" in result

View File

@@ -1,11 +1,11 @@
import os
import tempfile
from pathlib import Path
-from unittest.mock import ANY, MagicMock
+from unittest.mock import MagicMock
import pytest
-from embedchain.models.data_type import DataType
+from crewai_tools.rag.data_types import DataType
from crewai_tools.tools import (
    CodeDocsSearchTool,
    CSVSearchTool,
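These tests rely on a shared mock_adapter fixture that is not part of this diff; conceptually it is something like the hypothetical sketch below, with query.return_value set per test:

    @pytest.fixture
    def mock_adapter():
        adapter = MagicMock()
        adapter.query.return_value = "this is a test"  # overridden where needed
        return adapter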
@@ -49,7 +49,7 @@ def test_pdf_search_tool(mock_adapter):
    result = tool._run(query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()
@@ -58,7 +58,7 @@ def test_pdf_search_tool(mock_adapter):
    result = tool._run(pdf="test.pdf", query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
def test_txt_search_tool():
@@ -82,7 +82,7 @@ def test_docx_search_tool(mock_adapter):
    result = tool._run(search_query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()
@@ -91,7 +91,7 @@ def test_docx_search_tool(mock_adapter):
    result = tool._run(docx="test.docx", search_query="test content")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
def test_json_search_tool():
@@ -114,7 +114,7 @@ def test_xml_search_tool(mock_adapter):
    result = tool._run(search_query="test XML", xml="test.xml")
    assert "this is a test" in result.lower()
    mock_adapter.add.assert_called_once_with("test.xml")
-    mock_adapter.query.assert_called_once_with("test XML")
+    mock_adapter.query.assert_called_once_with("test XML", similarity_threshold=0.6, limit=5)
def test_csv_search_tool():
@@ -153,8 +153,8 @@ def test_website_search_tool(mock_adapter):
    tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
    result = tool._run(search_query=search_query)
-    mock_adapter.query.assert_called_once_with("what is crewai?")
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
    assert "this is a test" in result.lower()
@@ -164,8 +164,8 @@ def test_website_search_tool(mock_adapter):
    tool = WebsiteSearchTool(adapter=mock_adapter)
    result = tool._run(website=website, search_query=search_query)
-    mock_adapter.query.assert_called_once_with("what is crewai?")
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
    assert "this is a test" in result.lower()
@@ -185,7 +185,7 @@ def test_youtube_video_search_tool(mock_adapter):
    mock_adapter.add.assert_called_once_with(
        youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
    )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()
@@ -197,7 +197,7 @@ def test_youtube_video_search_tool(mock_adapter):
    mock_adapter.add.assert_called_once_with(
        youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
    )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
def test_youtube_channel_search_tool(mock_adapter):
@@ -213,7 +213,7 @@ def test_youtube_channel_search_tool(mock_adapter):
    mock_adapter.add.assert_called_once_with(
        youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
    )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()
@@ -227,7 +227,7 @@ def test_youtube_channel_search_tool(mock_adapter):
    mock_adapter.add.assert_called_once_with(
        youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
    )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
def test_code_docs_search_tool(mock_adapter):
@@ -239,7 +239,7 @@ def test_code_docs_search_tool(mock_adapter):
    result = tool._run(search_query=search_query)
    assert "test documentation" in result
    mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()
@@ -248,7 +248,7 @@ def test_code_docs_search_tool(mock_adapter):
    result = tool._run(docs_url=docs_url, search_query=search_query)
    assert "test documentation" in result
    mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
def test_github_search_tool(mock_adapter):
@@ -264,9 +264,11 @@ def test_github_search_tool(mock_adapter):
    result = tool._run(search_query="tell me about crewai repo")
    assert "repo description" in result
    mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code"], "gh_token": "test_token"}
    )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
    # ensure content types provided by the run call are used
    mock_adapter.query.reset_mock()
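The run-level call that drives the next assertion is not visible in this extract; a hedged sketch of its likely shape (the content_types parameter name is inferred from the metadata above and should be treated as an assumption):

    tool._run(
        search_query="tell me about crewai repo",
        content_types=["code", "issue"],
    )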
@@ -280,9 +282,11 @@ def test_github_search_tool(mock_adapter):
    )
    assert "repo description" in result
    mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "issue"], "gh_token": "test_token"}
    )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
    # ensure default content types are used if not provided
    mock_adapter.query.reset_mock()
@@ -295,9 +299,11 @@ def test_github_search_tool(mock_adapter):
    )
    assert "repo description" in result
    mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "repo", "pr", "issue"], "gh_token": "test_token"}
    )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
    # ensure nothing is added if no repo is provided
    mock_adapter.query.reset_mock()
@@ -306,4 +312,4 @@ def test_github_search_tool(mock_adapter):
    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    result = tool._run(search_query="tell me about crewai repo")
    mock_adapter.add.assert_not_called()
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)