mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: replace embedchain with native crewai adapter (#451)
- Remove embedchain adapter; add crewai rag adapter and update all search tools - Add loaders: pdf, youtube (video & channel), github, docs site, mysql, postgresql - Add configurable similarity threshold, limit params, and embedding_model support - Improve chromadb compatibility (sanitize metadata, convert columns, fix chunking) - Fix xml encoding, Python 3.10 issues, and youtube url spoofing - Update crewai dependency and instructions; refresh uv.lock - Update tests for new rag adapter and search params
This commit is contained in:
215
src/crewai_tools/adapters/crewai_rag_adapter.py
Normal file
215
src/crewai_tools/adapters/crewai_rag_adapter.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from typing import Any, TypedDict, TypeAlias
|
||||
from typing_extensions import Unpack
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
data_type: DataType
|
||||
metadata: dict[str, Any]
|
||||
website: str
|
||||
url: str
|
||||
file_path: str | Path
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
|
||||
"""Adapter that uses CrewAI's native RAG system.
|
||||
|
||||
Supports custom vector database configuration through the config parameter.
|
||||
"""
|
||||
|
||||
collection_name: str = "default"
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
config: RagConfigType | None = None
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize the CrewAI RAG client after model initialization."""
|
||||
if self.config is not None:
|
||||
self._client = create_client(self.config)
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
self._client.get_or_create_collection(collection_name=self.collection_name)
|
||||
|
||||
def query(self, question: str, similarity_threshold: float | None = None, limit: int | None = None) -> str:
|
||||
"""Query the knowledge base with a question.
|
||||
|
||||
Args:
|
||||
question: The question to ask
|
||||
similarity_threshold: Minimum similarity score for results (default: 0.6)
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Relevant content from the knowledge base
|
||||
"""
|
||||
search_limit = limit if limit is not None else self.limit
|
||||
search_threshold = similarity_threshold if similarity_threshold is not None else self.similarity_threshold
|
||||
|
||||
results: list[SearchResult] = self._client.search(
|
||||
collection_name=self.collection_name,
|
||||
query=question,
|
||||
limit=search_limit,
|
||||
score_threshold=search_threshold
|
||||
)
|
||||
|
||||
if not results:
|
||||
return "No relevant content found."
|
||||
|
||||
contents: list[str] = []
|
||||
for result in results:
|
||||
content: str = result.get("content", "")
|
||||
if content:
|
||||
contents.append(content)
|
||||
|
||||
return "\n\n".join(contents)
|
||||
|
||||
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
This method handles various input types and converts them to documents
|
||||
for the vector database. It supports the data_type parameter for
|
||||
compatibility with existing tools.
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
**kwargs: Additional parameters including data_type, metadata, etc.
|
||||
"""
|
||||
from crewai_tools.rag.data_types import DataTypes, DataType
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
from crewai_tools.rag.base_loader import LoaderResult
|
||||
import os
|
||||
|
||||
documents: list[BaseRecord] = []
|
||||
data_type: DataType | None = kwargs.get("data_type")
|
||||
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
|
||||
|
||||
for arg in args:
|
||||
source_ref: str
|
||||
if isinstance(arg, dict):
|
||||
source_ref = str(arg.get("source", arg.get("content", "")))
|
||||
else:
|
||||
source_ref = str(arg)
|
||||
|
||||
if not data_type:
|
||||
data_type = DataTypes.from_content(source_ref)
|
||||
|
||||
if data_type == DataType.DIRECTORY:
|
||||
if not os.path.isdir(source_ref):
|
||||
raise ValueError(f"Directory does not exist: {source_ref}")
|
||||
|
||||
# Define binary and non-text file extensions to skip
|
||||
binary_extensions = {'.pyc', '.pyo', '.png', '.jpg', '.jpeg', '.gif',
|
||||
'.bmp', '.ico', '.svg', '.webp', '.pdf', '.zip',
|
||||
'.tar', '.gz', '.bz2', '.7z', '.rar', '.exe',
|
||||
'.dll', '.so', '.dylib', '.bin', '.dat', '.db',
|
||||
'.sqlite', '.class', '.jar', '.war', '.ear'}
|
||||
|
||||
for root, dirs, files in os.walk(source_ref):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.')]
|
||||
|
||||
for filename in files:
|
||||
if filename.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip binary files based on extension
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
if file_ext in binary_extensions:
|
||||
continue
|
||||
|
||||
# Skip __pycache__ directories
|
||||
if '__pycache__' in root:
|
||||
continue
|
||||
|
||||
file_path: str = os.path.join(root, filename)
|
||||
try:
|
||||
file_data_type: DataType = DataTypes.from_content(file_path)
|
||||
file_loader = file_data_type.get_loader()
|
||||
file_chunker = file_data_type.get_chunker()
|
||||
|
||||
file_source = SourceContent(file_path)
|
||||
file_result: LoaderResult = file_loader.load(file_source)
|
||||
|
||||
file_chunks = file_chunker.chunk(file_result.content)
|
||||
|
||||
for chunk_idx, file_chunk in enumerate(file_chunks):
|
||||
file_metadata: dict[str, Any] = base_metadata.copy()
|
||||
file_metadata.update(file_result.metadata)
|
||||
file_metadata["data_type"] = str(file_data_type)
|
||||
file_metadata["file_path"] = file_path
|
||||
file_metadata["chunk_index"] = chunk_idx
|
||||
file_metadata["total_chunks"] = len(file_chunks)
|
||||
|
||||
if isinstance(arg, dict):
|
||||
file_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()).hexdigest()
|
||||
|
||||
documents.append({
|
||||
"doc_id": chunk_id,
|
||||
"content": file_chunk,
|
||||
"metadata": sanitize_metadata_for_chromadb(file_metadata)
|
||||
})
|
||||
except Exception:
|
||||
# Silently skip files that can't be processed
|
||||
continue
|
||||
else:
|
||||
metadata: dict[str, Any] = base_metadata.copy()
|
||||
|
||||
if data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
|
||||
DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
|
||||
if not os.path.isfile(source_ref):
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
|
||||
loader = data_type.get_loader()
|
||||
chunker = data_type.get_chunker()
|
||||
|
||||
source_content = SourceContent(source_ref)
|
||||
loader_result: LoaderResult = loader.load(source_content)
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_metadata: dict[str, Any] = metadata.copy()
|
||||
chunk_metadata.update(loader_result.metadata)
|
||||
chunk_metadata["data_type"] = str(data_type)
|
||||
chunk_metadata["chunk_index"] = i
|
||||
chunk_metadata["total_chunks"] = len(chunks)
|
||||
chunk_metadata["source"] = source_ref
|
||||
|
||||
if isinstance(arg, dict):
|
||||
chunk_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(f"{loader_result.doc_id}_{i}_{chunk}".encode()).hexdigest()
|
||||
|
||||
documents.append({
|
||||
"doc_id": chunk_id,
|
||||
"content": chunk,
|
||||
"metadata": sanitize_metadata_for_chromadb(chunk_metadata)
|
||||
})
|
||||
|
||||
if documents:
|
||||
self._client.add_documents(
|
||||
collection_name=self.collection_name,
|
||||
documents=documents
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
|
||||
embedchain_app: Any # Will be App when embedchain is available
|
||||
summarize: bool = False
|
||||
|
||||
def __init__(self, **data):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**data)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize)
|
||||
)
|
||||
if self.summarize:
|
||||
return result
|
||||
return "\n\n".join([source[0] for source in sources])
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.embedchain_app.add(*args, **kwargs)
|
||||
@@ -1,41 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class PDFEmbedchainAdapter(Adapter):
|
||||
embedchain_app: Any # Will be App when embedchain is available
|
||||
summarize: bool = False
|
||||
src: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**data)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
where = (
|
||||
{"app_id": self.embedchain_app.config.id, "source": self.src}
|
||||
if self.src
|
||||
else None
|
||||
)
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize), where=where
|
||||
)
|
||||
if self.summarize:
|
||||
return result
|
||||
return "\n\n".join([source[0] for source in sources])
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.src = args[0] if args else None
|
||||
self.embedchain_app.add(*args, **kwargs)
|
||||
@@ -112,7 +112,10 @@ class RecursiveCharacterTextSplitter:
|
||||
if separator == "":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
if self._keep_separator and separator == " ":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
@@ -133,7 +136,10 @@ class RecursiveCharacterTextSplitter:
|
||||
if separator == "":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
if self._keep_separator and separator == " ":
|
||||
doc = "".join(current_doc)
|
||||
else:
|
||||
doc = separator.join(current_doc)
|
||||
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
|
||||
@@ -25,6 +25,8 @@ class DataType(str, Enum):
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
@@ -34,6 +36,7 @@ class DataType(str, Enum):
|
||||
from importlib import import_module
|
||||
|
||||
chunkers = {
|
||||
DataType.PDF_FILE: ("text_chunker", "TextChunker"),
|
||||
DataType.TEXT_FILE: ("text_chunker", "TextChunker"),
|
||||
DataType.TEXT: ("text_chunker", "TextChunker"),
|
||||
DataType.DOCX: ("text_chunker", "DocxChunker"),
|
||||
@@ -45,9 +48,18 @@ class DataType(str, Enum):
|
||||
DataType.XML: ("structured_chunker", "XmlChunker"),
|
||||
|
||||
DataType.WEBSITE: ("web_chunker", "WebsiteChunker"),
|
||||
DataType.DIRECTORY: ("text_chunker", "TextChunker"),
|
||||
DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"),
|
||||
DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"),
|
||||
DataType.GITHUB: ("text_chunker", "TextChunker"),
|
||||
DataType.DOCS_SITE: ("text_chunker", "TextChunker"),
|
||||
DataType.MYSQL: ("text_chunker", "TextChunker"),
|
||||
DataType.POSTGRES: ("text_chunker", "TextChunker"),
|
||||
}
|
||||
|
||||
module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker"))
|
||||
if self not in chunkers:
|
||||
raise ValueError(f"No chunker defined for {self}")
|
||||
module_name, class_name = chunkers[self]
|
||||
module_path = f"crewai_tools.rag.chunkers.{module_name}"
|
||||
|
||||
try:
|
||||
@@ -60,6 +72,7 @@ class DataType(str, Enum):
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
DataType.TEXT: ("text_loader", "TextLoader"),
|
||||
DataType.XML: ("xml_loader", "XMLLoader"),
|
||||
@@ -69,9 +82,17 @@ class DataType(str, Enum):
|
||||
DataType.DOCX: ("docx_loader", "DOCXLoader"),
|
||||
DataType.CSV: ("csv_loader", "CSVLoader"),
|
||||
DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"),
|
||||
DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"),
|
||||
DataType.YOUTUBE_CHANNEL: ("youtube_channel_loader", "YoutubeChannelLoader"),
|
||||
DataType.GITHUB: ("github_loader", "GithubLoader"),
|
||||
DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"),
|
||||
DataType.MYSQL: ("mysql_loader", "MySQLLoader"),
|
||||
DataType.POSTGRES: ("postgres_loader", "PostgresLoader"),
|
||||
}
|
||||
|
||||
module_name, class_name = loaders.get(self, ("text_loader", "TextLoader"))
|
||||
if self not in loaders:
|
||||
raise ValueError(f"No loader defined for {self}")
|
||||
module_name, class_name = loaders[self]
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
|
||||
@@ -6,6 +6,9 @@ from crewai_tools.rag.loaders.json_loader import JSONLoader
|
||||
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
|
||||
from crewai_tools.rag.loaders.csv_loader import CSVLoader
|
||||
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
|
||||
from crewai_tools.rag.loaders.pdf_loader import PDFLoader
|
||||
from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader
|
||||
from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader
|
||||
|
||||
__all__ = [
|
||||
"TextFileLoader",
|
||||
@@ -17,4 +20,7 @@ __all__ = [
|
||||
"DOCXLoader",
|
||||
"CSVLoader",
|
||||
"DirectoryLoader",
|
||||
"PDFLoader",
|
||||
"YoutubeVideoLoader",
|
||||
"YoutubeChannelLoader",
|
||||
]
|
||||
|
||||
98
src/crewai_tools/rag/loaders/docs_site_loader.py
Normal file
98
src/crewai_tools/rag/loaders/docs_site_loader.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Documentation site loader."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class DocsSiteLoader(BaseLoader):
|
||||
"""Loader for documentation websites."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a documentation site.
|
||||
|
||||
Args:
|
||||
source: Documentation site URL
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
LoaderResult with documentation content
|
||||
"""
|
||||
docs_url = source.source
|
||||
|
||||
try:
|
||||
response = requests.get(docs_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}")
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.decompose()
|
||||
|
||||
title = soup.find("title")
|
||||
title_text = title.get_text(strip=True) if title else "Documentation"
|
||||
|
||||
main_content = None
|
||||
for selector in ["main", "article", '[role="main"]', ".content", "#content", ".documentation"]:
|
||||
main_content = soup.select_one(selector)
|
||||
if main_content:
|
||||
break
|
||||
|
||||
if not main_content:
|
||||
main_content = soup.find("body")
|
||||
|
||||
if not main_content:
|
||||
raise ValueError(f"Unable to extract content from documentation site: {docs_url}")
|
||||
|
||||
text_parts = [f"Title: {title_text}", ""]
|
||||
|
||||
headings = main_content.find_all(["h1", "h2", "h3"])
|
||||
if headings:
|
||||
text_parts.append("Table of Contents:")
|
||||
for heading in headings[:15]:
|
||||
level = int(heading.name[1])
|
||||
indent = " " * (level - 1)
|
||||
text_parts.append(f"{indent}- {heading.get_text(strip=True)}")
|
||||
text_parts.append("")
|
||||
|
||||
text = main_content.get_text(separator="\n", strip=True)
|
||||
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
||||
text_parts.extend(lines)
|
||||
|
||||
nav_links = []
|
||||
for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]:
|
||||
nav = soup.select_one(nav_selector)
|
||||
if nav:
|
||||
links = nav.find_all("a", href=True)
|
||||
for link in links[:20]:
|
||||
href = link["href"]
|
||||
if not href.startswith(("http://", "https://", "mailto:", "#")):
|
||||
full_url = urljoin(docs_url, href)
|
||||
nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")
|
||||
|
||||
if nav_links:
|
||||
text_parts.append("")
|
||||
text_parts.append("Related documentation pages:")
|
||||
text_parts.extend(nav_links[:10])
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": docs_url,
|
||||
"title": title_text,
|
||||
"domain": urlparse(docs_url).netloc
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=docs_url, content=content)
|
||||
)
|
||||
110
src/crewai_tools/rag/loaders/github_loader.py
Normal file
110
src/crewai_tools/rag/loaders/github_loader.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""GitHub repository content loader."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from github import Github, GithubException
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class GithubLoader(BaseLoader):
|
||||
"""Loader for GitHub repository content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a GitHub repository.
|
||||
|
||||
Args:
|
||||
source: GitHub repository URL
|
||||
**kwargs: Additional arguments including gh_token and content_types
|
||||
|
||||
Returns:
|
||||
LoaderResult with repository content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
gh_token = metadata.get("gh_token")
|
||||
content_types = metadata.get("content_types", ["code", "repo"])
|
||||
|
||||
repo_url = source.source
|
||||
if not repo_url.startswith("https://github.com/"):
|
||||
raise ValueError(f"Invalid GitHub URL: {repo_url}")
|
||||
|
||||
parts = repo_url.replace("https://github.com/", "").strip("/").split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid GitHub repository URL: {repo_url}")
|
||||
|
||||
repo_name = f"{parts[0]}/{parts[1]}"
|
||||
|
||||
g = Github(gh_token) if gh_token else Github()
|
||||
|
||||
try:
|
||||
repo = g.get_repo(repo_name)
|
||||
except GithubException as e:
|
||||
raise ValueError(f"Unable to access repository {repo_name}: {e}")
|
||||
|
||||
all_content = []
|
||||
|
||||
if "repo" in content_types:
|
||||
all_content.append(f"Repository: {repo.full_name}")
|
||||
all_content.append(f"Description: {repo.description or 'No description'}")
|
||||
all_content.append(f"Language: {repo.language or 'Not specified'}")
|
||||
all_content.append(f"Stars: {repo.stargazers_count}")
|
||||
all_content.append(f"Forks: {repo.forks_count}")
|
||||
all_content.append("")
|
||||
|
||||
if "code" in content_types:
|
||||
try:
|
||||
readme = repo.get_readme()
|
||||
all_content.append("README:")
|
||||
all_content.append(readme.decoded_content.decode("utf-8", errors="ignore"))
|
||||
all_content.append("")
|
||||
except GithubException:
|
||||
pass
|
||||
|
||||
try:
|
||||
contents = repo.get_contents("")
|
||||
if isinstance(contents, list):
|
||||
all_content.append("Repository structure:")
|
||||
for content_file in contents[:20]:
|
||||
all_content.append(f"- {content_file.path} ({content_file.type})")
|
||||
all_content.append("")
|
||||
except GithubException:
|
||||
pass
|
||||
|
||||
if "pr" in content_types:
|
||||
prs = repo.get_pulls(state="open")
|
||||
pr_list = list(prs[:5])
|
||||
if pr_list:
|
||||
all_content.append("Recent Pull Requests:")
|
||||
for pr in pr_list:
|
||||
all_content.append(f"- PR #{pr.number}: {pr.title}")
|
||||
if pr.body:
|
||||
body_preview = pr.body[:200].replace("\n", " ")
|
||||
all_content.append(f" {body_preview}")
|
||||
all_content.append("")
|
||||
|
||||
if "issue" in content_types:
|
||||
issues = repo.get_issues(state="open")
|
||||
issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5]
|
||||
if issue_list:
|
||||
all_content.append("Recent Issues:")
|
||||
for issue in issue_list:
|
||||
all_content.append(f"- Issue #{issue.number}: {issue.title}")
|
||||
if issue.body:
|
||||
body_preview = issue.body[:200].replace("\n", " ")
|
||||
all_content.append(f" {body_preview}")
|
||||
all_content.append("")
|
||||
|
||||
if not all_content:
|
||||
raise ValueError(f"No content could be loaded from repository: {repo_url}")
|
||||
|
||||
content = "\n".join(all_content)
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": repo_url,
|
||||
"repo": repo_name,
|
||||
"content_types": content_types
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=repo_url, content=content)
|
||||
)
|
||||
99
src/crewai_tools/rag/loaders/mysql_loader.py
Normal file
99
src/crewai_tools/rag/loaders/mysql_loader.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""MySQL database loader."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pymysql
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class MySQLLoader(BaseLoader):
|
||||
"""Loader for MySQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a MySQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for MySQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["mysql", "mysql+pymysql"]:
|
||||
raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 3306,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = pymysql.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except pymysql.Error as e:
|
||||
raise ValueError(f"MySQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}")
|
||||
72
src/crewai_tools/rag/loaders/pdf_loader.py
Normal file
72
src/crewai_tools/rag/loaders/pdf_loader.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""PDF loader for extracting text from PDF files."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract text from a PDF file.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
"""
|
||||
try:
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. "
|
||||
"Install with: uv add pypdf"
|
||||
)
|
||||
|
||||
file_path = source.source
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
|
||||
text_content = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"file_type": "pdf"
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {str(e)}")
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=str(file_path),
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content)
|
||||
)
|
||||
99
src/crewai_tools/rag/loaders/postgres_loader.py
Normal file
99
src/crewai_tools/rag/loaders/postgres_loader.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""PostgreSQL database loader."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PostgresLoader(BaseLoader):
|
||||
"""Loader for PostgreSQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a PostgreSQL database table.
|
||||
|
||||
Args:
|
||||
source: SQL query (e.g., "SELECT * FROM table_name")
|
||||
**kwargs: Additional arguments including db_uri
|
||||
|
||||
Returns:
|
||||
LoaderResult with database content
|
||||
"""
|
||||
metadata = kwargs.get("metadata", {})
|
||||
db_uri = metadata.get("db_uri")
|
||||
|
||||
if not db_uri:
|
||||
raise ValueError("Database URI is required for PostgreSQL loader")
|
||||
|
||||
query = source.source
|
||||
|
||||
parsed = urlparse(db_uri)
|
||||
if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]:
|
||||
raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}")
|
||||
|
||||
connection_params = {
|
||||
"host": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 5432,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"cursor_factory": RealDictCursor
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = psycopg2.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
content = "No data found in the table"
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={"source": query, "row_count": 0},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
|
||||
columns = list(rows[0].keys())
|
||||
text_parts.append(f"Columns: {', '.join(columns)}")
|
||||
text_parts.append(f"Total rows: {len(rows)}")
|
||||
text_parts.append("")
|
||||
|
||||
for i, row in enumerate(rows, 1):
|
||||
text_parts.append(f"Row {i}:")
|
||||
for col, val in row.items():
|
||||
if val is not None:
|
||||
text_parts.append(f" {col}: {val}")
|
||||
text_parts.append("")
|
||||
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
if len(content) > 100000:
|
||||
content = content[:100000] + "\n\n[Content truncated...]"
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
metadata={
|
||||
"source": query,
|
||||
"database": connection_params["database"],
|
||||
"row_count": len(rows),
|
||||
"columns": columns
|
||||
},
|
||||
doc_id=self.generate_doc_id(source_ref=query, content=content)
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except psycopg2.Error as e:
|
||||
raise ValueError(f"PostgreSQL database error: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}")
|
||||
@@ -11,7 +11,7 @@ class XMLLoader(BaseLoader):
|
||||
|
||||
if source_content.is_url():
|
||||
content = self._load_from_url(source_ref, kwargs)
|
||||
elif os.path.exists(source_ref):
|
||||
elif source_content.path_exists():
|
||||
content = self._load_from_file(source_ref)
|
||||
|
||||
return self._parse_xml(content, source_ref)
|
||||
|
||||
141
src/crewai_tools/rag/loaders/youtube_channel_loader.py
Normal file
141
src/crewai_tools/rag/loaders/youtube_channel_loader.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""YouTube channel loader for extracting content from YouTube channels."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class YoutubeChannelLoader(BaseLoader):
|
||||
"""Loader for YouTube channels."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract content from a YouTube channel.
|
||||
|
||||
Args:
|
||||
source: The source content containing the YouTube channel URL
|
||||
|
||||
Returns:
|
||||
LoaderResult with channel content
|
||||
|
||||
Raises:
|
||||
ImportError: If required YouTube libraries aren't installed
|
||||
ValueError: If the URL is not a valid YouTube channel URL
|
||||
"""
|
||||
try:
|
||||
from pytube import Channel
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"YouTube channel support requires pytube. "
|
||||
"Install with: uv add pytube"
|
||||
)
|
||||
|
||||
channel_url = source.source
|
||||
|
||||
if not any(pattern in channel_url for pattern in ['youtube.com/channel/', 'youtube.com/c/', 'youtube.com/@', 'youtube.com/user/']):
|
||||
raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"source": channel_url,
|
||||
"data_type": "youtube_channel"
|
||||
}
|
||||
|
||||
try:
|
||||
channel = Channel(channel_url)
|
||||
|
||||
metadata["channel_name"] = channel.channel_name
|
||||
metadata["channel_id"] = channel.channel_id
|
||||
|
||||
max_videos = kwargs.get('max_videos', 10)
|
||||
video_urls = list(channel.video_urls)[:max_videos]
|
||||
metadata["num_videos_loaded"] = len(video_urls)
|
||||
metadata["total_videos"] = len(list(channel.video_urls))
|
||||
|
||||
content_parts = [
|
||||
f"YouTube Channel: {channel.channel_name}",
|
||||
f"Channel ID: {channel.channel_id}",
|
||||
f"Total Videos: {metadata['total_videos']}",
|
||||
f"Videos Loaded: {metadata['num_videos_loaded']}",
|
||||
"\n--- Video Summaries ---\n"
|
||||
]
|
||||
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from pytube import YouTube
|
||||
|
||||
for i, video_url in enumerate(video_urls, 1):
|
||||
try:
|
||||
video_id = self._extract_video_id(video_url)
|
||||
if not video_id:
|
||||
continue
|
||||
yt = YouTube(video_url)
|
||||
title = yt.title or f"Video {i}"
|
||||
description = yt.description[:200] if yt.description else "No description"
|
||||
|
||||
content_parts.append(f"\n{i}. {title}")
|
||||
content_parts.append(f" URL: {video_url}")
|
||||
content_parts.append(f" Description: {description}...")
|
||||
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
transcript = None
|
||||
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(['en'])
|
||||
except:
|
||||
try:
|
||||
transcript = transcript_list.find_generated_transcript(['en'])
|
||||
except:
|
||||
transcript = next(iter(transcript_list), None)
|
||||
|
||||
if transcript:
|
||||
transcript_data = transcript.fetch()
|
||||
text_parts = []
|
||||
char_count = 0
|
||||
for entry in transcript_data:
|
||||
text = entry.text.strip() if hasattr(entry, 'text') else ''
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
char_count += len(text)
|
||||
if char_count > 500:
|
||||
break
|
||||
|
||||
if text_parts:
|
||||
preview = ' '.join(text_parts)[:500]
|
||||
content_parts.append(f" Transcript Preview: {preview}...")
|
||||
except:
|
||||
content_parts.append(" Transcript: Not available")
|
||||
|
||||
except Exception as e:
|
||||
content_parts.append(f"\n{i}. Error loading video: {str(e)}")
|
||||
|
||||
except ImportError:
|
||||
for i, video_url in enumerate(video_urls, 1):
|
||||
content_parts.append(f"\n{i}. {video_url}")
|
||||
|
||||
content = '\n'.join(content_parts)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to load YouTube channel {channel_url}: {str(e)}") from e
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=channel_url,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=channel_url, content=content)
|
||||
)
|
||||
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from YouTube URL."""
|
||||
patterns = [
|
||||
r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
123
src/crewai_tools/rag/loaders/youtube_video_loader.py
Normal file
123
src/crewai_tools/rag/loaders/youtube_video_loader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""YouTube video loader for extracting transcripts from YouTube videos."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class YoutubeVideoLoader(BaseLoader):
|
||||
"""Loader for YouTube videos."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract transcript from a YouTube video.
|
||||
|
||||
Args:
|
||||
source: The source content containing the YouTube URL
|
||||
|
||||
Returns:
|
||||
LoaderResult with transcript content
|
||||
|
||||
Raises:
|
||||
ImportError: If required YouTube libraries aren't installed
|
||||
ValueError: If the URL is not a valid YouTube video URL
|
||||
"""
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"YouTube support requires youtube-transcript-api. "
|
||||
"Install with: uv add youtube-transcript-api"
|
||||
)
|
||||
|
||||
video_url = source.source
|
||||
video_id = self._extract_video_id(video_url)
|
||||
|
||||
if not video_id:
|
||||
raise ValueError(f"Invalid YouTube URL: {video_url}")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"source": video_url,
|
||||
"video_id": video_id,
|
||||
"data_type": "youtube_video"
|
||||
}
|
||||
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
|
||||
transcript = None
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(['en'])
|
||||
except:
|
||||
try:
|
||||
transcript = transcript_list.find_generated_transcript(['en'])
|
||||
except:
|
||||
transcript = next(iter(transcript_list))
|
||||
|
||||
if transcript:
|
||||
metadata["language"] = transcript.language
|
||||
metadata["is_generated"] = transcript.is_generated
|
||||
|
||||
transcript_data = transcript.fetch()
|
||||
|
||||
text_content = []
|
||||
for entry in transcript_data:
|
||||
text = entry.text.strip() if hasattr(entry, 'text') else ''
|
||||
if text:
|
||||
text_content.append(text)
|
||||
|
||||
content = ' '.join(text_content)
|
||||
|
||||
try:
|
||||
from pytube import YouTube
|
||||
yt = YouTube(video_url)
|
||||
metadata["title"] = yt.title
|
||||
metadata["author"] = yt.author
|
||||
metadata["length_seconds"] = yt.length
|
||||
metadata["description"] = yt.description[:500] if yt.description else None
|
||||
|
||||
if yt.title:
|
||||
content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"No transcript available for YouTube video: {video_id}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to extract transcript from YouTube video {video_id}: {str(e)}") from e
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=video_url,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=video_url, content=content)
|
||||
)
|
||||
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from various YouTube URL formats."""
|
||||
patterns = [
|
||||
r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
if hostname:
|
||||
hostname_lower = hostname.lower()
|
||||
# Allow youtube.com and any subdomain of youtube.com, plus youtu.be shortener
|
||||
if hostname_lower == 'youtube.com' or hostname_lower.endswith('.youtube.com') or hostname_lower == 'youtu.be':
|
||||
query_params = parse_qs(parsed.query)
|
||||
if 'v' in query_params:
|
||||
return query_params['v'][0]
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,4 +1,29 @@
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
def compute_sha256(content: str) -> str:
|
||||
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sanitize metadata to ensure ChromaDB compatibility.
|
||||
|
||||
ChromaDB only accepts str, int, float, or bool values in metadata.
|
||||
This function converts other types to strings.
|
||||
|
||||
Args:
|
||||
metadata: Dictionary of metadata to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized metadata dictionary with only ChromaDB-compatible types
|
||||
"""
|
||||
sanitized = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
sanitized[key] = value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Convert lists/tuples to pipe-separated strings
|
||||
sanitized[key] = " | ".join(str(v) for v in value)
|
||||
else:
|
||||
# Convert other types to string
|
||||
sanitized[key] = str(value)
|
||||
return sanitized
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,15 @@ class CodeDocsSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docs_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docs_url: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedCSVSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,16 @@ class CSVSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, csv: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
csv: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
@@ -34,8 +29,6 @@ class DirectorySearchTool(RagTool):
|
||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
@@ -44,16 +37,15 @@ class DirectorySearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, directory: str) -> None:
|
||||
super().add(
|
||||
directory,
|
||||
loader=DirectoryLoader(config=dict(recursive=True)),
|
||||
)
|
||||
super().add(directory, data_type=DataType.DIRECTORY)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
directory: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedDOCXSearchToolSchema(BaseModel):
|
||||
@@ -48,15 +44,15 @@ class DOCXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docx: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Any:
|
||||
if docx is not None:
|
||||
self.add(docx)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import List, Optional, Type, Any
|
||||
from typing import List, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedGithubSearchToolSchema(BaseModel):
|
||||
@@ -42,7 +37,6 @@ class GithubSearchTool(RagTool):
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: Any | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -50,10 +44,7 @@ class GithubSearchTool(RagTool):
|
||||
content_types: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
if github_repo and content_types:
|
||||
self.add(repo=github_repo, content_types=content_types)
|
||||
@@ -67,11 +58,10 @@ class GithubSearchTool(RagTool):
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
content_types = content_types or self.content_types
|
||||
|
||||
super().add(
|
||||
f"repo:{repo} type:{','.join(content_types)}",
|
||||
data_type="github",
|
||||
loader=self._loader,
|
||||
f"https://github.com/{repo}",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": content_types, "gh_token": self.gh_token}
|
||||
)
|
||||
|
||||
def _run(
|
||||
@@ -79,10 +69,12 @@ class GithubSearchTool(RagTool):
|
||||
search_query: str,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if github_repo:
|
||||
self.add(
|
||||
repo=github_repo,
|
||||
content_types=content_types,
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -41,7 +41,9 @@ class JSONSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
json_path: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if json_path is not None:
|
||||
self.add(json_path)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -2,13 +2,9 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedMDXSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,15 @@ class MDXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, mdx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
mdx: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if mdx is not None:
|
||||
self.add(mdx)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.mysql import MySQLLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class MySQLSearchToolSchema(BaseModel):
|
||||
@@ -27,12 +22,8 @@ class MySQLSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "mysql"
|
||||
kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
|
||||
self.add(table_name)
|
||||
self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
|
||||
@@ -46,6 +37,8 @@ class MySQLSearchTool(RagTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -2,13 +2,8 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedPDFSearchToolSchema(BaseModel):
|
||||
@@ -41,15 +36,15 @@ class PDFSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
pdf: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
return super()._run(query=query)
|
||||
return super()._run(query=query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class PGSearchToolSchema(BaseModel):
|
||||
@@ -27,12 +22,8 @@ class PGSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "postgres"
|
||||
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
|
||||
self.add(table_name)
|
||||
self.add(table_name, data_type=DataType.POSTGRES, metadata={"db_uri": self.db_uri})
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
|
||||
@@ -46,6 +37,8 @@ class PGSearchTool(RagTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit, **kwargs)
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import portalocker
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def query(self, question: str) -> str:
|
||||
def query(
|
||||
self,
|
||||
question: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
"""Query the knowledge base with a question and return the answer."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC):
|
||||
|
||||
class RagTool(BaseTool):
|
||||
class _AdapterPlaceholder(Adapter):
|
||||
def query(self, question: str) -> str:
|
||||
def query(
|
||||
self,
|
||||
question: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -34,28 +44,149 @@ class RagTool(BaseTool):
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base that can be used to answer questions."
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: dict[str, Any] | None = None
|
||||
config: Any | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
try:
|
||||
from embedchain import App
|
||||
except ImportError:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
parsed_config = self._parse_config(self.config)
|
||||
|
||||
with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
|
||||
self.adapter = EmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
self.adapter = CrewAIRagAdapter(
|
||||
collection_name="rag_tool_collection",
|
||||
summarize=self.summarize,
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
limit=self.limit,
|
||||
config=parsed_config,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _parse_config(self, config: Any) -> Any:
|
||||
"""Parse complex config format to extract provider-specific config.
|
||||
|
||||
Raises:
|
||||
ValueError: If the config format is invalid or uses unsupported providers.
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
if isinstance(config, dict) and "provider" in config:
|
||||
return config
|
||||
|
||||
if isinstance(config, dict):
|
||||
if "vectordb" in config:
|
||||
vectordb_config = config["vectordb"]
|
||||
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
|
||||
provider = vectordb_config["provider"]
|
||||
provider_config = vectordb_config.get("config", {})
|
||||
|
||||
supported_providers = ["chromadb", "qdrant"]
|
||||
if provider not in supported_providers:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
|
||||
)
|
||||
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, provider
|
||||
)
|
||||
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, "chromadb"
|
||||
)
|
||||
|
||||
return self._create_provider_config("chromadb", {}, embedding_function)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
|
||||
"""Create embedding function for the specified vector database provider."""
|
||||
embedding_provider = embedding_config.get("provider")
|
||||
embedding_model_config = embedding_config.get("config", {}).copy()
|
||||
|
||||
if "model" in embedding_model_config:
|
||||
embedding_model_config["model_name"] = embedding_model_config.pop("model")
|
||||
|
||||
factory_config = {"provider": embedding_provider, **embedding_model_config}
|
||||
|
||||
if embedding_provider == "openai" and "api_key" not in factory_config:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
factory_config["api_key"] = api_key
|
||||
|
||||
print(f"Creating embedding function with config: {factory_config}")
|
||||
|
||||
if provider == "chromadb":
|
||||
embedding_func = get_embedding_function(factory_config)
|
||||
print(f"Created embedding function: {embedding_func}")
|
||||
print(f"Embedding function type: {type(embedding_func)}")
|
||||
return embedding_func
|
||||
|
||||
elif provider == "qdrant":
|
||||
chromadb_func = get_embedding_function(factory_config)
|
||||
|
||||
def qdrant_embed_fn(text: str) -> list[float]:
|
||||
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
|
||||
|
||||
Args:
|
||||
text: The input text to embed.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding.
|
||||
"""
|
||||
embeddings = chromadb_func([text])
|
||||
return embeddings[0] if embeddings and len(embeddings) > 0 else []
|
||||
|
||||
return cast(Any, qdrant_embed_fn)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _create_provider_config(
|
||||
provider: str, provider_config: dict, embedding_function: Any
|
||||
) -> Any:
|
||||
"""Create proper provider config object."""
|
||||
if provider == "chromadb":
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return ChromaDBConfig(**config_kwargs)
|
||||
|
||||
elif provider == "qdrant":
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return QdrantConfig(**config_kwargs)
|
||||
|
||||
return None
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -66,5 +197,13 @@ class RagTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
threshold = (
|
||||
similarity_threshold
|
||||
if similarity_threshold is not None
|
||||
else self.similarity_threshold
|
||||
)
|
||||
result_limit = limit if limit is not None else self.limit
|
||||
return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"
|
||||
|
||||
@@ -39,7 +39,9 @@ class TXTSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
txt: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedWebsiteSearchToolSchema(BaseModel):
|
||||
@@ -44,15 +39,15 @@ class WebsiteSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, website: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(website, data_type=DataType.WEB_PAGE)
|
||||
super().add(website, data_type=DataType.WEBSITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
website: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if website is not None:
|
||||
self.add(website)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -39,7 +39,9 @@ class XMLSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
xml: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if xml is not None:
|
||||
self.add(xml)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedYoutubeChannelSearchToolSchema(BaseModel):
|
||||
@@ -55,7 +50,9 @@ class YoutubeChannelSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
youtube_channel_handle: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if youtube_channel_handle is not None:
|
||||
self.add(youtube_channel_handle)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedYoutubeVideoSearchToolSchema(BaseModel):
|
||||
@@ -44,15 +40,15 @@ class YoutubeVideoSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, youtube_video_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
youtube_video_url: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if youtube_video_url is not None:
|
||||
self.add(youtube_video_url)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,43 +1,54 @@
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest import mock
|
||||
from pathlib import Path
|
||||
|
||||
from pytest import fixture
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def mock_embedchain_db_uri():
|
||||
with NamedTemporaryFile() as tmp:
|
||||
uri = f"sqlite:///{tmp.name}"
|
||||
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
|
||||
yield
|
||||
|
||||
|
||||
def test_custom_llm_and_embedder():
|
||||
def test_rag_tool_initialization():
|
||||
"""Test that RagTool initializes with CrewAI adapter by default."""
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="openai",
|
||||
config=dict(model="gpt-3.5-custom"),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="openai",
|
||||
config=dict(model="text-embedding-3-custom"),
|
||||
),
|
||||
)
|
||||
)
|
||||
tool = MyTool()
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, EmbedchainAdapter)
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
adapter = cast(EmbedchainAdapter, tool.adapter)
|
||||
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
|
||||
assert (
|
||||
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
|
||||
)
|
||||
adapter = cast(CrewAIRagAdapter, tool.adapter)
|
||||
assert adapter.collection_name == "rag_tool_collection"
|
||||
assert adapter._client is not None
|
||||
|
||||
|
||||
def test_rag_tool_add_and_query():
|
||||
"""Test adding content and querying with RagTool."""
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
tool.add("The sky is blue on a clear day.")
|
||||
tool.add("Machine learning is a subset of artificial intelligence.")
|
||||
|
||||
result = tool._run(query="What color is the sky?")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
result = tool._run(query="Tell me about machine learning")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
|
||||
def test_rag_tool_with_file():
|
||||
"""Test RagTool with file content."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Python is a programming language known for its simplicity.")
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
tool.add(str(test_file))
|
||||
|
||||
result = tool._run(query="What is Python?")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools import (
|
||||
CodeDocsSearchTool,
|
||||
CSVSearchTool,
|
||||
@@ -49,7 +49,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
result = tool._run(query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -58,7 +58,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
result = tool._run(pdf="test.pdf", query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_txt_search_tool():
|
||||
@@ -82,7 +82,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -91,7 +91,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
result = tool._run(docx="test.docx", search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_json_search_tool():
|
||||
@@ -114,7 +114,7 @@ def test_xml_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="test XML", xml="test.xml")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.xml")
|
||||
mock_adapter.query.assert_called_once_with("test XML")
|
||||
mock_adapter.query.assert_called_once_with("test XML", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_csv_search_tool():
|
||||
@@ -153,8 +153,8 @@ def test_website_search_tool(mock_adapter):
|
||||
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
|
||||
result = tool._run(search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
@@ -164,8 +164,8 @@ def test_website_search_tool(mock_adapter):
|
||||
tool = WebsiteSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(website=website, search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
@@ -185,7 +185,7 @@ def test_youtube_video_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -197,7 +197,7 @@ def test_youtube_video_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_youtube_channel_search_tool(mock_adapter):
|
||||
@@ -213,7 +213,7 @@ def test_youtube_channel_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -227,7 +227,7 @@ def test_youtube_channel_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_code_docs_search_tool(mock_adapter):
|
||||
@@ -239,7 +239,7 @@ def test_code_docs_search_tool(mock_adapter):
|
||||
result = tool._run(search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -248,7 +248,7 @@ def test_code_docs_search_tool(mock_adapter):
|
||||
result = tool._run(docs_url=docs_url, search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_github_search_tool(mock_adapter):
|
||||
@@ -264,9 +264,11 @@ def test_github_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure content types provided by run call is used
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -280,9 +282,11 @@ def test_github_search_tool(mock_adapter):
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code", "issue"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure default content types are used if not provided
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -295,9 +299,11 @@ def test_github_search_tool(mock_adapter):
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code", "repo", "pr", "issue"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure nothing is added if no repo is provided
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -306,4 +312,4 @@ def test_github_search_tool(mock_adapter):
|
||||
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
mock_adapter.add.assert_not_called()
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
Reference in New Issue
Block a user