From e29ca9ec282b9c20d0a8e5a969c33ebbedbd9d42 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 18 Sep 2025 19:02:22 -0400 Subject: [PATCH] feat: replace embedchain with native crewai adapter (#451) - Remove embedchain adapter; add crewai rag adapter and update all search tools - Add loaders: pdf, youtube (video & channel), github, docs site, mysql, postgresql - Add configurable similarity threshold, limit params, and embedding_model support - Improve chromadb compatibility (sanitize metadata, convert columns, fix chunking) - Fix xml encoding, Python 3.10 issues, and youtube url spoofing - Update crewai dependency and instructions; refresh uv.lock - Update tests for new rag adapter and search params --- .../adapters/crewai_rag_adapter.py | 215 ++++++++++++++++++ .../adapters/embedchain_adapter.py | 34 --- .../adapters/pdf_embedchain_adapter.py | 41 ---- src/crewai_tools/rag/chunkers/base_chunker.py | 10 +- src/crewai_tools/rag/data_types.py | 25 +- src/crewai_tools/rag/loaders/__init__.py | 6 + .../rag/loaders/docs_site_loader.py | 98 ++++++++ src/crewai_tools/rag/loaders/github_loader.py | 110 +++++++++ src/crewai_tools/rag/loaders/mysql_loader.py | 99 ++++++++ src/crewai_tools/rag/loaders/pdf_loader.py | 72 ++++++ .../rag/loaders/postgres_loader.py | 99 ++++++++ src/crewai_tools/rag/loaders/xml_loader.py | 2 +- .../rag/loaders/youtube_channel_loader.py | 141 ++++++++++++ .../rag/loaders/youtube_video_loader.py | 123 ++++++++++ src/crewai_tools/rag/misc.py | 25 ++ .../code_docs_search_tool.py | 12 +- .../tools/csv_search_tool/csv_search_tool.py | 13 +- .../directory_search_tool.py | 18 +- .../docx_search_tool/docx_search_tool.py | 12 +- .../github_search_tool/github_search_tool.py | 26 +-- .../json_search_tool/json_search_tool.py | 4 +- .../tools/mdx_search_tool/mdx_search_tool.py | 12 +- .../mysql_search_tool/mysql_search_tool.py | 17 +- .../tools/pdf_search_tool/pdf_search_tool.py | 13 +- .../tools/pg_search_tool/pg_search_tool.py | 17 +- src/crewai_tools/tools/rag/rag_tool.py | 175 ++++++++++++-- .../tools/txt_search_tool/txt_search_tool.py | 4 +- .../website_search/website_search_tool.py | 15 +- .../tools/xml_search_tool/xml_search_tool.py | 4 +- .../youtube_channel_search_tool.py | 11 +- .../youtube_video_search_tool.py | 12 +- tests/tools/rag/rag_tool_test.py | 75 +++--- tests/tools/test_search_tools.py | 54 +++-- 33 files changed, 1317 insertions(+), 277 deletions(-) create mode 100644 src/crewai_tools/adapters/crewai_rag_adapter.py delete mode 100644 src/crewai_tools/adapters/embedchain_adapter.py delete mode 100644 src/crewai_tools/adapters/pdf_embedchain_adapter.py create mode 100644 src/crewai_tools/rag/loaders/docs_site_loader.py create mode 100644 src/crewai_tools/rag/loaders/github_loader.py create mode 100644 src/crewai_tools/rag/loaders/mysql_loader.py create mode 100644 src/crewai_tools/rag/loaders/pdf_loader.py create mode 100644 src/crewai_tools/rag/loaders/postgres_loader.py create mode 100644 src/crewai_tools/rag/loaders/youtube_channel_loader.py create mode 100644 src/crewai_tools/rag/loaders/youtube_video_loader.py diff --git a/src/crewai_tools/adapters/crewai_rag_adapter.py b/src/crewai_tools/adapters/crewai_rag_adapter.py new file mode 100644 index 000000000..c2142ad4b --- /dev/null +++ b/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -0,0 +1,215 @@ +"""Adapter for CrewAI's native RAG system.""" + +from typing import Any, TypedDict, TypeAlias +from typing_extensions import Unpack +from pathlib import Path +import hashlib + +from pydantic import Field, 
PrivateAttr +from crewai.rag.config.utils import get_rag_client +from crewai.rag.config.types import RagConfigType +from crewai.rag.types import BaseRecord, SearchResult +from crewai.rag.core.base_client import BaseClient +from crewai.rag.factory import create_client + +from crewai_tools.tools.rag.rag_tool import Adapter +from crewai_tools.rag.data_types import DataType +from crewai_tools.rag.misc import sanitize_metadata_for_chromadb +from crewai_tools.rag.chunkers.base_chunker import BaseChunker + +ContentItem: TypeAlias = str | Path | dict[str, Any] + +class AddDocumentParams(TypedDict, total=False): + """Parameters for adding documents to the RAG system.""" + data_type: DataType + metadata: dict[str, Any] + website: str + url: str + file_path: str | Path + github_url: str + youtube_url: str + directory_path: str | Path + + +class CrewAIRagAdapter(Adapter): + """Adapter that uses CrewAI's native RAG system. + + Supports custom vector database configuration through the config parameter. + """ + + collection_name: str = "default" + summarize: bool = False + similarity_threshold: float = 0.6 + limit: int = 5 + config: RagConfigType | None = None + _client: BaseClient | None = PrivateAttr(default=None) + + def model_post_init(self, __context: Any) -> None: + """Initialize the CrewAI RAG client after model initialization.""" + if self.config is not None: + self._client = create_client(self.config) + else: + self._client = get_rag_client() + self._client.get_or_create_collection(collection_name=self.collection_name) + + def query(self, question: str, similarity_threshold: float | None = None, limit: int | None = None) -> str: + """Query the knowledge base with a question. + + Args: + question: The question to ask + similarity_threshold: Minimum similarity score for results (default: 0.6) + limit: Maximum number of results to return (default: 5) + + Returns: + Relevant content from the knowledge base + """ + search_limit = limit if limit is not None else self.limit + search_threshold = similarity_threshold if similarity_threshold is not None else self.similarity_threshold + + results: list[SearchResult] = self._client.search( + collection_name=self.collection_name, + query=question, + limit=search_limit, + score_threshold=search_threshold + ) + + if not results: + return "No relevant content found." + + contents: list[str] = [] + for result in results: + content: str = result.get("content", "") + if content: + contents.append(content) + + return "\n\n".join(contents) + + def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: + """Add content to the knowledge base. + + This method handles various input types and converts them to documents + for the vector database. It supports the data_type parameter for + compatibility with existing tools. + + Args: + *args: Content items to add (strings, paths, or document dicts) + **kwargs: Additional parameters including data_type, metadata, etc. 
+        """
+        from crewai_tools.rag.data_types import DataTypes, DataType
+        from crewai_tools.rag.source_content import SourceContent
+        from crewai_tools.rag.base_loader import LoaderResult
+        import os
+
+        documents: list[BaseRecord] = []
+        data_type: DataType | None = kwargs.get("data_type")
+        base_metadata: dict[str, Any] = kwargs.get("metadata", {})
+
+        for arg in args:
+            source_ref: str
+            if isinstance(arg, dict):
+                source_ref = str(arg.get("source", arg.get("content", "")))
+            else:
+                source_ref = str(arg)
+
+            # Resolve the data type per argument: an explicit data_type kwarg
+            # applies to every argument, but a type inferred from one argument
+            # must not leak into the next.
+            arg_data_type: DataType = data_type or DataTypes.from_content(source_ref)
+
+            if arg_data_type == DataType.DIRECTORY:
+                if not os.path.isdir(source_ref):
+                    raise ValueError(f"Directory does not exist: {source_ref}")
+
+                # Define binary and non-text file extensions to skip
+                binary_extensions = {'.pyc', '.pyo', '.png', '.jpg', '.jpeg', '.gif',
+                                     '.bmp', '.ico', '.svg', '.webp', '.pdf', '.zip',
+                                     '.tar', '.gz', '.bz2', '.7z', '.rar', '.exe',
+                                     '.dll', '.so', '.dylib', '.bin', '.dat', '.db',
+                                     '.sqlite', '.class', '.jar', '.war', '.ear'}
+
+                for root, dirs, files in os.walk(source_ref):
+                    dirs[:] = [d for d in dirs if not d.startswith('.')]
+
+                    for filename in files:
+                        if filename.startswith('.'):
+                            continue
+
+                        # Skip binary files based on extension
+                        file_ext = os.path.splitext(filename)[1].lower()
+                        if file_ext in binary_extensions:
+                            continue
+
+                        # Skip files inside __pycache__ directories
+                        if '__pycache__' in root:
+                            continue
+
+                        file_path: str = os.path.join(root, filename)
+                        try:
+                            file_data_type: DataType = DataTypes.from_content(file_path)
+                            file_loader = file_data_type.get_loader()
+                            file_chunker = file_data_type.get_chunker()
+
+                            file_source = SourceContent(file_path)
+                            file_result: LoaderResult = file_loader.load(file_source)
+
+                            file_chunks = file_chunker.chunk(file_result.content)
+
+                            for chunk_idx, file_chunk in enumerate(file_chunks):
+                                file_metadata: dict[str, Any] = base_metadata.copy()
+                                file_metadata.update(file_result.metadata)
+                                file_metadata["data_type"] = str(file_data_type)
+                                file_metadata["file_path"] = file_path
+                                file_metadata["chunk_index"] = chunk_idx
+                                file_metadata["total_chunks"] = len(file_chunks)
+
+                                if isinstance(arg, dict):
+                                    file_metadata.update(arg.get("metadata", {}))
+
+                                chunk_id = hashlib.sha256(f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()).hexdigest()
+
+                                documents.append({
+                                    "doc_id": chunk_id,
+                                    "content": file_chunk,
+                                    "metadata": sanitize_metadata_for_chromadb(file_metadata)
+                                })
+                        except Exception:
+                            # Silently skip files that can't be processed
+                            continue
+            else:
+                metadata: dict[str, Any] = base_metadata.copy()
+
+                if arg_data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
+                                     DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
+                    if not os.path.isfile(source_ref):
+                        raise FileNotFoundError(f"File does not exist: {source_ref}")
+
+                loader = arg_data_type.get_loader()
+                chunker = arg_data_type.get_chunker()
+
+                source_content = SourceContent(source_ref)
+                loader_result: LoaderResult = loader.load(source_content)
+
+                chunks = chunker.chunk(loader_result.content)
+
+                for i, chunk in enumerate(chunks):
+                    chunk_metadata: dict[str, Any] = metadata.copy()
+                    chunk_metadata.update(loader_result.metadata)
+                    chunk_metadata["data_type"] = str(arg_data_type)
+                    chunk_metadata["chunk_index"] = i
+                    chunk_metadata["total_chunks"] = len(chunks)
+                    chunk_metadata["source"] = source_ref
+
+                    if isinstance(arg, dict):
+                        chunk_metadata.update(arg.get("metadata", {}))
+
+                    chunk_id = hashlib.sha256(f"{loader_result.doc_id}_{i}_{chunk}".encode()).hexdigest()
+
+
documents.append({ + "doc_id": chunk_id, + "content": chunk, + "metadata": sanitize_metadata_for_chromadb(chunk_metadata) + }) + + if documents: + self._client.add_documents( + collection_name=self.collection_name, + documents=documents + ) \ No newline at end of file diff --git a/src/crewai_tools/adapters/embedchain_adapter.py b/src/crewai_tools/adapters/embedchain_adapter.py deleted file mode 100644 index 1e7b83c0b..000000000 --- a/src/crewai_tools/adapters/embedchain_adapter.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Any - -from crewai_tools.tools.rag.rag_tool import Adapter - -try: - from embedchain import App - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - - -class EmbedchainAdapter(Adapter): - embedchain_app: Any # Will be App when embedchain is available - summarize: bool = False - - def __init__(self, **data): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") - super().__init__(**data) - - def query(self, question: str) -> str: - result, sources = self.embedchain_app.query( - question, citations=True, dry_run=(not self.summarize) - ) - if self.summarize: - return result - return "\n\n".join([source[0] for source in sources]) - - def add( - self, - *args: Any, - **kwargs: Any, - ) -> None: - self.embedchain_app.add(*args, **kwargs) diff --git a/src/crewai_tools/adapters/pdf_embedchain_adapter.py b/src/crewai_tools/adapters/pdf_embedchain_adapter.py deleted file mode 100644 index aa682c84f..000000000 --- a/src/crewai_tools/adapters/pdf_embedchain_adapter.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Any, Optional - -from crewai_tools.tools.rag.rag_tool import Adapter - -try: - from embedchain import App - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - - -class PDFEmbedchainAdapter(Adapter): - embedchain_app: Any # Will be App when embedchain is available - summarize: bool = False - src: Optional[str] = None - - def __init__(self, **data): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. 
Please install it with `pip install crewai-tools[embedchain]`") - super().__init__(**data) - - def query(self, question: str) -> str: - where = ( - {"app_id": self.embedchain_app.config.id, "source": self.src} - if self.src - else None - ) - result, sources = self.embedchain_app.query( - question, citations=True, dry_run=(not self.summarize), where=where - ) - if self.summarize: - return result - return "\n\n".join([source[0] for source in sources]) - - def add( - self, - *args: Any, - **kwargs: Any, - ) -> None: - self.src = args[0] if args else None - self.embedchain_app.add(*args, **kwargs) diff --git a/src/crewai_tools/rag/chunkers/base_chunker.py b/src/crewai_tools/rag/chunkers/base_chunker.py index deafbfc7a..cadf66f16 100644 --- a/src/crewai_tools/rag/chunkers/base_chunker.py +++ b/src/crewai_tools/rag/chunkers/base_chunker.py @@ -112,7 +112,10 @@ class RecursiveCharacterTextSplitter: if separator == "": doc = "".join(current_doc) else: - doc = separator.join(current_doc) + if self._keep_separator and separator == " ": + doc = "".join(current_doc) + else: + doc = separator.join(current_doc) if doc: docs.append(doc) @@ -133,7 +136,10 @@ class RecursiveCharacterTextSplitter: if separator == "": doc = "".join(current_doc) else: - doc = separator.join(current_doc) + if self._keep_separator and separator == " ": + doc = "".join(current_doc) + else: + doc = separator.join(current_doc) if doc: docs.append(doc) diff --git a/src/crewai_tools/rag/data_types.py b/src/crewai_tools/rag/data_types.py index d2d265cce..1e6f0d8c6 100644 --- a/src/crewai_tools/rag/data_types.py +++ b/src/crewai_tools/rag/data_types.py @@ -25,6 +25,8 @@ class DataType(str, Enum): # Web types WEBSITE = "website" DOCS_SITE = "docs_site" + YOUTUBE_VIDEO = "youtube_video" + YOUTUBE_CHANNEL = "youtube_channel" # Raw types TEXT = "text" @@ -34,6 +36,7 @@ class DataType(str, Enum): from importlib import import_module chunkers = { + DataType.PDF_FILE: ("text_chunker", "TextChunker"), DataType.TEXT_FILE: ("text_chunker", "TextChunker"), DataType.TEXT: ("text_chunker", "TextChunker"), DataType.DOCX: ("text_chunker", "DocxChunker"), @@ -45,9 +48,18 @@ class DataType(str, Enum): DataType.XML: ("structured_chunker", "XmlChunker"), DataType.WEBSITE: ("web_chunker", "WebsiteChunker"), + DataType.DIRECTORY: ("text_chunker", "TextChunker"), + DataType.YOUTUBE_VIDEO: ("text_chunker", "TextChunker"), + DataType.YOUTUBE_CHANNEL: ("text_chunker", "TextChunker"), + DataType.GITHUB: ("text_chunker", "TextChunker"), + DataType.DOCS_SITE: ("text_chunker", "TextChunker"), + DataType.MYSQL: ("text_chunker", "TextChunker"), + DataType.POSTGRES: ("text_chunker", "TextChunker"), } - module_name, class_name = chunkers.get(self, ("default_chunker", "DefaultChunker")) + if self not in chunkers: + raise ValueError(f"No chunker defined for {self}") + module_name, class_name = chunkers[self] module_path = f"crewai_tools.rag.chunkers.{module_name}" try: @@ -60,6 +72,7 @@ class DataType(str, Enum): from importlib import import_module loaders = { + DataType.PDF_FILE: ("pdf_loader", "PDFLoader"), DataType.TEXT_FILE: ("text_loader", "TextFileLoader"), DataType.TEXT: ("text_loader", "TextLoader"), DataType.XML: ("xml_loader", "XMLLoader"), @@ -69,9 +82,17 @@ class DataType(str, Enum): DataType.DOCX: ("docx_loader", "DOCXLoader"), DataType.CSV: ("csv_loader", "CSVLoader"), DataType.DIRECTORY: ("directory_loader", "DirectoryLoader"), + DataType.YOUTUBE_VIDEO: ("youtube_video_loader", "YoutubeVideoLoader"), + DataType.YOUTUBE_CHANNEL: ("youtube_channel_loader", 
"YoutubeChannelLoader"), + DataType.GITHUB: ("github_loader", "GithubLoader"), + DataType.DOCS_SITE: ("docs_site_loader", "DocsSiteLoader"), + DataType.MYSQL: ("mysql_loader", "MySQLLoader"), + DataType.POSTGRES: ("postgres_loader", "PostgresLoader"), } - module_name, class_name = loaders.get(self, ("text_loader", "TextLoader")) + if self not in loaders: + raise ValueError(f"No loader defined for {self}") + module_name, class_name = loaders[self] module_path = f"crewai_tools.rag.loaders.{module_name}" try: module = import_module(module_path) diff --git a/src/crewai_tools/rag/loaders/__init__.py b/src/crewai_tools/rag/loaders/__init__.py index 503651468..dc7424833 100644 --- a/src/crewai_tools/rag/loaders/__init__.py +++ b/src/crewai_tools/rag/loaders/__init__.py @@ -6,6 +6,9 @@ from crewai_tools.rag.loaders.json_loader import JSONLoader from crewai_tools.rag.loaders.docx_loader import DOCXLoader from crewai_tools.rag.loaders.csv_loader import CSVLoader from crewai_tools.rag.loaders.directory_loader import DirectoryLoader +from crewai_tools.rag.loaders.pdf_loader import PDFLoader +from crewai_tools.rag.loaders.youtube_video_loader import YoutubeVideoLoader +from crewai_tools.rag.loaders.youtube_channel_loader import YoutubeChannelLoader __all__ = [ "TextFileLoader", @@ -17,4 +20,7 @@ __all__ = [ "DOCXLoader", "CSVLoader", "DirectoryLoader", + "PDFLoader", + "YoutubeVideoLoader", + "YoutubeChannelLoader", ] diff --git a/src/crewai_tools/rag/loaders/docs_site_loader.py b/src/crewai_tools/rag/loaders/docs_site_loader.py new file mode 100644 index 000000000..b87ebc419 --- /dev/null +++ b/src/crewai_tools/rag/loaders/docs_site_loader.py @@ -0,0 +1,98 @@ +"""Documentation site loader.""" + +from typing import Any +from urllib.parse import urljoin, urlparse + +import requests +from bs4 import BeautifulSoup + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class DocsSiteLoader(BaseLoader): + """Loader for documentation websites.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load content from a documentation site. 
+ + Args: + source: Documentation site URL + **kwargs: Additional arguments + + Returns: + LoaderResult with documentation content + """ + docs_url = source.source + + try: + response = requests.get(docs_url, timeout=30) + response.raise_for_status() + except requests.RequestException as e: + raise ValueError(f"Unable to fetch documentation from {docs_url}: {e}") + + soup = BeautifulSoup(response.text, "html.parser") + + for script in soup(["script", "style"]): + script.decompose() + + title = soup.find("title") + title_text = title.get_text(strip=True) if title else "Documentation" + + main_content = None + for selector in ["main", "article", '[role="main"]', ".content", "#content", ".documentation"]: + main_content = soup.select_one(selector) + if main_content: + break + + if not main_content: + main_content = soup.find("body") + + if not main_content: + raise ValueError(f"Unable to extract content from documentation site: {docs_url}") + + text_parts = [f"Title: {title_text}", ""] + + headings = main_content.find_all(["h1", "h2", "h3"]) + if headings: + text_parts.append("Table of Contents:") + for heading in headings[:15]: + level = int(heading.name[1]) + indent = " " * (level - 1) + text_parts.append(f"{indent}- {heading.get_text(strip=True)}") + text_parts.append("") + + text = main_content.get_text(separator="\n", strip=True) + lines = [line.strip() for line in text.split("\n") if line.strip()] + text_parts.extend(lines) + + nav_links = [] + for nav_selector in ["nav", ".sidebar", ".toc", ".navigation"]: + nav = soup.select_one(nav_selector) + if nav: + links = nav.find_all("a", href=True) + for link in links[:20]: + href = link["href"] + if not href.startswith(("http://", "https://", "mailto:", "#")): + full_url = urljoin(docs_url, href) + nav_links.append(f"- {link.get_text(strip=True)}: {full_url}") + + if nav_links: + text_parts.append("") + text_parts.append("Related documentation pages:") + text_parts.extend(nav_links[:10]) + + content = "\n".join(text_parts) + + if len(content) > 100000: + content = content[:100000] + "\n\n[Content truncated...]" + + return LoaderResult( + content=content, + metadata={ + "source": docs_url, + "title": title_text, + "domain": urlparse(docs_url).netloc + }, + doc_id=self.generate_doc_id(source_ref=docs_url, content=content) + ) \ No newline at end of file diff --git a/src/crewai_tools/rag/loaders/github_loader.py b/src/crewai_tools/rag/loaders/github_loader.py new file mode 100644 index 000000000..b033c2071 --- /dev/null +++ b/src/crewai_tools/rag/loaders/github_loader.py @@ -0,0 +1,110 @@ +"""GitHub repository content loader.""" + +from typing import Any + +from github import Github, GithubException + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class GithubLoader(BaseLoader): + """Loader for GitHub repository content.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load content from a GitHub repository. 
+ + Args: + source: GitHub repository URL + **kwargs: Additional arguments including gh_token and content_types + + Returns: + LoaderResult with repository content + """ + metadata = kwargs.get("metadata", {}) + gh_token = metadata.get("gh_token") + content_types = metadata.get("content_types", ["code", "repo"]) + + repo_url = source.source + if not repo_url.startswith("https://github.com/"): + raise ValueError(f"Invalid GitHub URL: {repo_url}") + + parts = repo_url.replace("https://github.com/", "").strip("/").split("/") + if len(parts) < 2: + raise ValueError(f"Invalid GitHub repository URL: {repo_url}") + + repo_name = f"{parts[0]}/{parts[1]}" + + g = Github(gh_token) if gh_token else Github() + + try: + repo = g.get_repo(repo_name) + except GithubException as e: + raise ValueError(f"Unable to access repository {repo_name}: {e}") + + all_content = [] + + if "repo" in content_types: + all_content.append(f"Repository: {repo.full_name}") + all_content.append(f"Description: {repo.description or 'No description'}") + all_content.append(f"Language: {repo.language or 'Not specified'}") + all_content.append(f"Stars: {repo.stargazers_count}") + all_content.append(f"Forks: {repo.forks_count}") + all_content.append("") + + if "code" in content_types: + try: + readme = repo.get_readme() + all_content.append("README:") + all_content.append(readme.decoded_content.decode("utf-8", errors="ignore")) + all_content.append("") + except GithubException: + pass + + try: + contents = repo.get_contents("") + if isinstance(contents, list): + all_content.append("Repository structure:") + for content_file in contents[:20]: + all_content.append(f"- {content_file.path} ({content_file.type})") + all_content.append("") + except GithubException: + pass + + if "pr" in content_types: + prs = repo.get_pulls(state="open") + pr_list = list(prs[:5]) + if pr_list: + all_content.append("Recent Pull Requests:") + for pr in pr_list: + all_content.append(f"- PR #{pr.number}: {pr.title}") + if pr.body: + body_preview = pr.body[:200].replace("\n", " ") + all_content.append(f" {body_preview}") + all_content.append("") + + if "issue" in content_types: + issues = repo.get_issues(state="open") + issue_list = [i for i in list(issues[:10]) if not i.pull_request][:5] + if issue_list: + all_content.append("Recent Issues:") + for issue in issue_list: + all_content.append(f"- Issue #{issue.number}: {issue.title}") + if issue.body: + body_preview = issue.body[:200].replace("\n", " ") + all_content.append(f" {body_preview}") + all_content.append("") + + if not all_content: + raise ValueError(f"No content could be loaded from repository: {repo_url}") + + content = "\n".join(all_content) + return LoaderResult( + content=content, + metadata={ + "source": repo_url, + "repo": repo_name, + "content_types": content_types + }, + doc_id=self.generate_doc_id(source_ref=repo_url, content=content) + ) \ No newline at end of file diff --git a/src/crewai_tools/rag/loaders/mysql_loader.py b/src/crewai_tools/rag/loaders/mysql_loader.py new file mode 100644 index 000000000..79a95e678 --- /dev/null +++ b/src/crewai_tools/rag/loaders/mysql_loader.py @@ -0,0 +1,99 @@ +"""MySQL database loader.""" + +from typing import Any +from urllib.parse import urlparse + +import pymysql + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class MySQLLoader(BaseLoader): + """Loader for MySQL database content.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load content 
from a MySQL database table. + + Args: + source: SQL query (e.g., "SELECT * FROM table_name") + **kwargs: Additional arguments including db_uri + + Returns: + LoaderResult with database content + """ + metadata = kwargs.get("metadata", {}) + db_uri = metadata.get("db_uri") + + if not db_uri: + raise ValueError("Database URI is required for MySQL loader") + + query = source.source + + parsed = urlparse(db_uri) + if parsed.scheme not in ["mysql", "mysql+pymysql"]: + raise ValueError(f"Invalid MySQL URI scheme: {parsed.scheme}") + + connection_params = { + "host": parsed.hostname or "localhost", + "port": parsed.port or 3306, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path.lstrip("/") if parsed.path else None, + "charset": "utf8mb4", + "cursorclass": pymysql.cursors.DictCursor + } + + if not connection_params["database"]: + raise ValueError("Database name is required in the URI") + + try: + connection = pymysql.connect(**connection_params) + try: + with connection.cursor() as cursor: + cursor.execute(query) + rows = cursor.fetchall() + + if not rows: + content = "No data found in the table" + return LoaderResult( + content=content, + metadata={"source": query, "row_count": 0}, + doc_id=self.generate_doc_id(source_ref=query, content=content) + ) + + text_parts = [] + + columns = list(rows[0].keys()) + text_parts.append(f"Columns: {', '.join(columns)}") + text_parts.append(f"Total rows: {len(rows)}") + text_parts.append("") + + for i, row in enumerate(rows, 1): + text_parts.append(f"Row {i}:") + for col, val in row.items(): + if val is not None: + text_parts.append(f" {col}: {val}") + text_parts.append("") + + content = "\n".join(text_parts) + + if len(content) > 100000: + content = content[:100000] + "\n\n[Content truncated...]" + + return LoaderResult( + content=content, + metadata={ + "source": query, + "database": connection_params["database"], + "row_count": len(rows), + "columns": columns + }, + doc_id=self.generate_doc_id(source_ref=query, content=content) + ) + finally: + connection.close() + except pymysql.Error as e: + raise ValueError(f"MySQL database error: {e}") + except Exception as e: + raise ValueError(f"Failed to load data from MySQL: {e}") \ No newline at end of file diff --git a/src/crewai_tools/rag/loaders/pdf_loader.py b/src/crewai_tools/rag/loaders/pdf_loader.py new file mode 100644 index 000000000..ed1dbfbfe --- /dev/null +++ b/src/crewai_tools/rag/loaders/pdf_loader.py @@ -0,0 +1,72 @@ +"""PDF loader for extracting text from PDF files.""" + +import os +from pathlib import Path +from typing import Any + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class PDFLoader(BaseLoader): + """Loader for PDF files.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load and extract text from a PDF file. + + Args: + source: The source content containing the PDF file path + + Returns: + LoaderResult with extracted text content + + Raises: + FileNotFoundError: If the PDF file doesn't exist + ImportError: If required PDF libraries aren't installed + """ + try: + import pypdf + except ImportError: + try: + import PyPDF2 as pypdf + except ImportError: + raise ImportError( + "PDF support requires pypdf or PyPDF2. 
" + "Install with: uv add pypdf" + ) + + file_path = source.source + + if not os.path.isfile(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + text_content = [] + metadata: dict[str, Any] = { + "source": str(file_path), + "file_name": Path(file_path).name, + "file_type": "pdf" + } + + try: + with open(file_path, 'rb') as file: + pdf_reader = pypdf.PdfReader(file) + metadata["num_pages"] = len(pdf_reader.pages) + + for page_num, page in enumerate(pdf_reader.pages, 1): + page_text = page.extract_text() + if page_text.strip(): + text_content.append(f"Page {page_num}:\n{page_text}") + except Exception as e: + raise ValueError(f"Error reading PDF file {file_path}: {str(e)}") + + if not text_content: + content = f"[PDF file with no extractable text: {Path(file_path).name}]" + else: + content = "\n\n".join(text_content) + + return LoaderResult( + content=content, + source=str(file_path), + metadata=metadata, + doc_id=self.generate_doc_id(source_ref=str(file_path), content=content) + ) \ No newline at end of file diff --git a/src/crewai_tools/rag/loaders/postgres_loader.py b/src/crewai_tools/rag/loaders/postgres_loader.py new file mode 100644 index 000000000..131dbdc3f --- /dev/null +++ b/src/crewai_tools/rag/loaders/postgres_loader.py @@ -0,0 +1,99 @@ +"""PostgreSQL database loader.""" + +from typing import Any +from urllib.parse import urlparse + +import psycopg2 +from psycopg2.extras import RealDictCursor + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class PostgresLoader(BaseLoader): + """Loader for PostgreSQL database content.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load content from a PostgreSQL database table. 
+ + Args: + source: SQL query (e.g., "SELECT * FROM table_name") + **kwargs: Additional arguments including db_uri + + Returns: + LoaderResult with database content + """ + metadata = kwargs.get("metadata", {}) + db_uri = metadata.get("db_uri") + + if not db_uri: + raise ValueError("Database URI is required for PostgreSQL loader") + + query = source.source + + parsed = urlparse(db_uri) + if parsed.scheme not in ["postgresql", "postgres", "postgresql+psycopg2"]: + raise ValueError(f"Invalid PostgreSQL URI scheme: {parsed.scheme}") + + connection_params = { + "host": parsed.hostname or "localhost", + "port": parsed.port or 5432, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path.lstrip("/") if parsed.path else None, + "cursor_factory": RealDictCursor + } + + if not connection_params["database"]: + raise ValueError("Database name is required in the URI") + + try: + connection = psycopg2.connect(**connection_params) + try: + with connection.cursor() as cursor: + cursor.execute(query) + rows = cursor.fetchall() + + if not rows: + content = "No data found in the table" + return LoaderResult( + content=content, + metadata={"source": query, "row_count": 0}, + doc_id=self.generate_doc_id(source_ref=query, content=content) + ) + + text_parts = [] + + columns = list(rows[0].keys()) + text_parts.append(f"Columns: {', '.join(columns)}") + text_parts.append(f"Total rows: {len(rows)}") + text_parts.append("") + + for i, row in enumerate(rows, 1): + text_parts.append(f"Row {i}:") + for col, val in row.items(): + if val is not None: + text_parts.append(f" {col}: {val}") + text_parts.append("") + + content = "\n".join(text_parts) + + if len(content) > 100000: + content = content[:100000] + "\n\n[Content truncated...]" + + return LoaderResult( + content=content, + metadata={ + "source": query, + "database": connection_params["database"], + "row_count": len(rows), + "columns": columns + }, + doc_id=self.generate_doc_id(source_ref=query, content=content) + ) + finally: + connection.close() + except psycopg2.Error as e: + raise ValueError(f"PostgreSQL database error: {e}") + except Exception as e: + raise ValueError(f"Failed to load data from PostgreSQL: {e}") \ No newline at end of file diff --git a/src/crewai_tools/rag/loaders/xml_loader.py b/src/crewai_tools/rag/loaders/xml_loader.py index ffafdb9d9..30c949932 100644 --- a/src/crewai_tools/rag/loaders/xml_loader.py +++ b/src/crewai_tools/rag/loaders/xml_loader.py @@ -11,7 +11,7 @@ class XMLLoader(BaseLoader): if source_content.is_url(): content = self._load_from_url(source_ref, kwargs) - elif os.path.exists(source_ref): + elif source_content.path_exists(): content = self._load_from_file(source_ref) return self._parse_xml(content, source_ref) diff --git a/src/crewai_tools/rag/loaders/youtube_channel_loader.py b/src/crewai_tools/rag/loaders/youtube_channel_loader.py new file mode 100644 index 000000000..3a62c5146 --- /dev/null +++ b/src/crewai_tools/rag/loaders/youtube_channel_loader.py @@ -0,0 +1,141 @@ +"""YouTube channel loader for extracting content from YouTube channels.""" + +import re +from typing import Any + +from crewai_tools.rag.base_loader import BaseLoader, LoaderResult +from crewai_tools.rag.source_content import SourceContent + + +class YoutubeChannelLoader(BaseLoader): + """Loader for YouTube channels.""" + + def load(self, source: SourceContent, **kwargs) -> LoaderResult: + """Load and extract content from a YouTube channel. 
+
+        Args:
+            source: The source content containing the YouTube channel URL
+
+        Returns:
+            LoaderResult with channel content
+
+        Raises:
+            ImportError: If required YouTube libraries aren't installed
+            ValueError: If the URL is not a valid YouTube channel URL
+        """
+        try:
+            from pytube import Channel
+        except ImportError:
+            raise ImportError(
+                "YouTube channel support requires pytube. "
+                "Install with: uv add pytube"
+            )
+
+        channel_url = source.source
+
+        if not any(pattern in channel_url for pattern in ['youtube.com/channel/', 'youtube.com/c/', 'youtube.com/@', 'youtube.com/user/']):
+            raise ValueError(f"Invalid YouTube channel URL: {channel_url}")
+
+        metadata: dict[str, Any] = {
+            "source": channel_url,
+            "data_type": "youtube_channel"
+        }
+
+        try:
+            channel = Channel(channel_url)
+
+            metadata["channel_name"] = channel.channel_name
+            metadata["channel_id"] = channel.channel_id
+
+            max_videos = kwargs.get('max_videos', 10)
+            video_urls = list(channel.video_urls)[:max_videos]
+            metadata["num_videos_loaded"] = len(video_urls)
+            metadata["total_videos"] = len(list(channel.video_urls))
+
+            content_parts = [
+                f"YouTube Channel: {channel.channel_name}",
+                f"Channel ID: {channel.channel_id}",
+                f"Total Videos: {metadata['total_videos']}",
+                f"Videos Loaded: {metadata['num_videos_loaded']}",
+                "\n--- Video Summaries ---\n"
+            ]
+
+            try:
+                from youtube_transcript_api import YouTubeTranscriptApi
+                from pytube import YouTube
+
+                for i, video_url in enumerate(video_urls, 1):
+                    try:
+                        video_id = self._extract_video_id(video_url)
+                        if not video_id:
+                            continue
+                        yt = YouTube(video_url)
+                        title = yt.title or f"Video {i}"
+                        description = yt.description[:200] if yt.description else "No description"
+
+                        content_parts.append(f"\n{i}. {title}")
+                        content_parts.append(f" URL: {video_url}")
+                        content_parts.append(f" Description: {description}...")
+
+                        try:
+                            api = YouTubeTranscriptApi()
+                            transcript_list = api.list(video_id)
+                            transcript = None
+
+                            try:
+                                transcript = transcript_list.find_transcript(['en'])
+                            except Exception:
+                                try:
+                                    transcript = transcript_list.find_generated_transcript(['en'])
+                                except Exception:
+                                    transcript = next(iter(transcript_list), None)
+
+                            if transcript:
+                                transcript_data = transcript.fetch()
+                                text_parts = []
+                                char_count = 0
+                                for entry in transcript_data:
+                                    text = entry.text.strip() if hasattr(entry, 'text') else ''
+                                    if text:
+                                        text_parts.append(text)
+                                        char_count += len(text)
+                                        if char_count > 500:
+                                            break
+
+                                if text_parts:
+                                    preview = ' '.join(text_parts)[:500]
+                                    content_parts.append(f" Transcript Preview: {preview}...")
+                        except Exception:
+                            content_parts.append(" Transcript: Not available")
+
+                    except Exception as e:
+                        content_parts.append(f"\n{i}. Error loading video: {str(e)}")
+
+            except ImportError:
+                for i, video_url in enumerate(video_urls, 1):
+                    content_parts.append(f"\n{i}. {video_url}")
+
+            content = '\n'.join(content_parts)
+
+        except Exception as e:
+            raise ValueError(f"Unable to load YouTube channel {channel_url}: {str(e)}") from e
+
+        return LoaderResult(
+            content=content,
+            source=channel_url,
+            metadata=metadata,
+            doc_id=self.generate_doc_id(source_ref=channel_url, content=content)
+        )
+
+    def _extract_video_id(self, url: str) -> str | None:
+        """Extract video ID from YouTube URL."""
+        patterns = [
+            r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, url)
+            if match:
+                return match.group(1)
+
+        return None
\ No newline at end of file
diff --git a/src/crewai_tools/rag/loaders/youtube_video_loader.py b/src/crewai_tools/rag/loaders/youtube_video_loader.py
new file mode 100644
index 000000000..6e0fd39e8
--- /dev/null
+++ b/src/crewai_tools/rag/loaders/youtube_video_loader.py
@@ -0,0 +1,123 @@
+"""YouTube video loader for extracting transcripts from YouTube videos."""
+
+import re
+from typing import Any
+from urllib.parse import urlparse, parse_qs
+
+from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
+from crewai_tools.rag.source_content import SourceContent
+
+
+class YoutubeVideoLoader(BaseLoader):
+    """Loader for YouTube videos."""
+
+    def load(self, source: SourceContent, **kwargs) -> LoaderResult:
+        """Load and extract transcript from a YouTube video.
+
+        Args:
+            source: The source content containing the YouTube URL
+
+        Returns:
+            LoaderResult with transcript content
+
+        Raises:
+            ImportError: If required YouTube libraries aren't installed
+            ValueError: If the URL is not a valid YouTube video URL
+        """
+        try:
+            from youtube_transcript_api import YouTubeTranscriptApi
+        except ImportError:
+            raise ImportError(
+                "YouTube support requires youtube-transcript-api. "
+                "Install with: uv add youtube-transcript-api"
+            )
+
+        video_url = source.source
+        video_id = self._extract_video_id(video_url)
+
+        if not video_id:
+            raise ValueError(f"Invalid YouTube URL: {video_url}")
+
+        metadata: dict[str, Any] = {
+            "source": video_url,
+            "video_id": video_id,
+            "data_type": "youtube_video"
+        }
+
+        try:
+            api = YouTubeTranscriptApi()
+            transcript_list = api.list(video_id)
+
+            transcript = None
+            try:
+                transcript = transcript_list.find_transcript(['en'])
+            except Exception:
+                try:
+                    transcript = transcript_list.find_generated_transcript(['en'])
+                except Exception:
+                    transcript = next(iter(transcript_list))
+
+            if transcript:
+                metadata["language"] = transcript.language
+                metadata["is_generated"] = transcript.is_generated
+
+                transcript_data = transcript.fetch()
+
+                text_content = []
+                for entry in transcript_data:
+                    text = entry.text.strip() if hasattr(entry, 'text') else ''
+                    if text:
+                        text_content.append(text)
+
+                content = ' '.join(text_content)
+
+                try:
+                    from pytube import YouTube
+                    yt = YouTube(video_url)
+                    metadata["title"] = yt.title
+                    metadata["author"] = yt.author
+                    metadata["length_seconds"] = yt.length
+                    metadata["description"] = yt.description[:500] if yt.description else None
+
+                    if yt.title:
+                        content = f"Title: {yt.title}\n\nAuthor: {yt.author or 'Unknown'}\n\nTranscript:\n{content}"
+                except Exception:
+                    pass
+            else:
+                raise ValueError(f"No transcript available for YouTube video: {video_id}")
+
+        except Exception as e:
+            raise ValueError(f"Unable to extract transcript from YouTube video {video_id}: {str(e)}") from e
+
+        return LoaderResult(
+            content=content,
+            source=video_url,
+            metadata=metadata,
+            doc_id=self.generate_doc_id(source_ref=video_url, content=content)
+        )
+
+    def _extract_video_id(self, url: str) -> str | None:
+        """Extract video ID from various YouTube URL formats."""
+        patterns = [
+            r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)',
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, url)
+            if match:
+                return match.group(1)
+
+        try:
+            parsed = urlparse(url)
+            hostname = parsed.hostname
+            if hostname:
+                hostname_lower = hostname.lower()
+                # Allow youtube.com and any subdomain of youtube.com, plus the youtu.be shortener
+                if hostname_lower == 'youtube.com' or hostname_lower.endswith('.youtube.com') or hostname_lower == 'youtu.be':
+                    query_params = parse_qs(parsed.query)
+                    if 'v' in query_params:
+                        return query_params['v'][0]
+        except Exception:
+            pass
+
+        return None
\ No newline at end of file
diff --git a/src/crewai_tools/rag/misc.py b/src/crewai_tools/rag/misc.py
index 5b95f804e..edec22f80 100644
--- a/src/crewai_tools/rag/misc.py
+++ b/src/crewai_tools/rag/misc.py
@@ -1,4 +1,29 @@
 import hashlib
+from typing import Any
 
 def compute_sha256(content: str) -> str:
     return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
+    """Sanitize metadata to ensure ChromaDB compatibility.
+
+    ChromaDB only accepts str, int, float, or bool values in metadata.
+    This function converts other types to strings. 
+
+    Args:
+        metadata: Dictionary of metadata to sanitize
+
+    Returns:
+        Sanitized metadata dictionary with only ChromaDB-compatible types
+    """
+    sanitized = {}
+    for key, value in metadata.items():
+        if value is None:
+            # ChromaDB rejects None metadata values, so drop the key entirely
+            continue
+        if isinstance(value, (str, int, float, bool)):
+            sanitized[key] = value
+        elif isinstance(value, (list, tuple)):
+            # Convert lists/tuples to pipe-separated strings
+            sanitized[key] = " | ".join(str(v) for v in value)
+        else:
+            # Convert other types to string
+            sanitized[key] = str(value)
+    return sanitized
diff --git a/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py b/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py
index 155b4390d..85be97894 100644
--- a/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py
+++ b/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py
@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
 
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedCodeDocsSearchToolSchema(BaseModel):
@@ -42,15 +38,15 @@ class CodeDocsSearchTool(RagTool):
         self._generate_description()
 
     def add(self, docs_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(docs_url, data_type=DataType.DOCS_SITE)
 
     def _run(
         self,
         search_query: str,
         docs_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if docs_url is not None:
             self.add(docs_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
diff --git a/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py b/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py
index 4be84efdd..ac95b1df5 100644
--- a/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py
+++ b/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py
@@ -1,14 +1,10 @@
 from typing import Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
 
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedCSVSearchToolSchema(BaseModel):
@@ -42,15 +38,16 @@ class CSVSearchTool(RagTool):
         self._generate_description()
 
     def add(self, csv: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. 
Please install it with `pip install crewai-tools[embedchain]`") super().add(csv, data_type=DataType.CSV) def _run( self, search_query: str, csv: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if csv is not None: self.add(csv) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) + diff --git a/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py b/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py index 30fdd52cc..9f0765f2d 100644 --- a/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py +++ b/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py @@ -1,14 +1,9 @@ from typing import Optional, Type -try: - from embedchain.loaders.directory_loader import DirectoryLoader - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - from pydantic import BaseModel, Field from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class FixedDirectorySearchToolSchema(BaseModel): @@ -34,8 +29,6 @@ class DirectorySearchTool(RagTool): args_schema: Type[BaseModel] = DirectorySearchToolSchema def __init__(self, directory: Optional[str] = None, **kwargs): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().__init__(**kwargs) if directory is not None: self.add(directory) @@ -44,16 +37,15 @@ class DirectorySearchTool(RagTool): self._generate_description() def add(self, directory: str) -> None: - super().add( - directory, - loader=DirectoryLoader(config=dict(recursive=True)), - ) + super().add(directory, data_type=DataType.DIRECTORY) def _run( self, search_query: str, directory: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if directory is not None: self.add(directory) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py b/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py index 97dab02cd..9a33bade9 100644 --- a/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py +++ b/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py @@ -1,14 +1,10 @@ from typing import Any, Optional, Type -try: - from embedchain.models.data_type import DataType - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False from pydantic import BaseModel, Field from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class FixedDOCXSearchToolSchema(BaseModel): @@ -48,15 +44,15 @@ class DOCXSearchTool(RagTool): self._generate_description() def add(self, docx: str) -> None: - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. 
Please install it with `pip install crewai-tools[embedchain]`") super().add(docx, data_type=DataType.DOCX) def _run( self, search_query: str, docx: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> Any: if docx is not None: self.add(docx) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/github_search_tool/github_search_tool.py b/src/crewai_tools/tools/github_search_tool/github_search_tool.py index afde4fe92..3a0fe42b6 100644 --- a/src/crewai_tools/tools/github_search_tool/github_search_tool.py +++ b/src/crewai_tools/tools/github_search_tool/github_search_tool.py @@ -1,14 +1,9 @@ -from typing import List, Optional, Type, Any +from typing import List, Optional, Type -try: - from embedchain.loaders.github import GithubLoader - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, Field from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class FixedGithubSearchToolSchema(BaseModel): @@ -42,7 +37,6 @@ class GithubSearchTool(RagTool): default_factory=lambda: ["code", "repo", "pr", "issue"], description="Content types you want to be included search, options: [code, repo, pr, issue]", ) - _loader: Any | None = PrivateAttr(default=None) def __init__( self, @@ -50,10 +44,7 @@ class GithubSearchTool(RagTool): content_types: Optional[List[str]] = None, **kwargs, ): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().__init__(**kwargs) - self._loader = GithubLoader(config={"token": self.gh_token}) if github_repo and content_types: self.add(repo=github_repo, content_types=content_types) @@ -67,11 +58,10 @@ class GithubSearchTool(RagTool): content_types: Optional[List[str]] = None, ) -> None: content_types = content_types or self.content_types - super().add( - f"repo:{repo} type:{','.join(content_types)}", - data_type="github", - loader=self._loader, + f"https://github.com/{repo}", + data_type=DataType.GITHUB, + metadata={"content_types": content_types, "gh_token": self.gh_token} ) def _run( @@ -79,10 +69,12 @@ class GithubSearchTool(RagTool): search_query: str, github_repo: Optional[str] = None, content_types: Optional[List[str]] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if github_repo: self.add( repo=github_repo, content_types=content_types, ) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/json_search_tool/json_search_tool.py b/src/crewai_tools/tools/json_search_tool/json_search_tool.py index 820323eec..49dad0ac7 100644 --- a/src/crewai_tools/tools/json_search_tool/json_search_tool.py +++ b/src/crewai_tools/tools/json_search_tool/json_search_tool.py @@ -41,7 +41,9 @@ class JSONSearchTool(RagTool): self, search_query: str, json_path: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if json_path is not None: self.add(json_path) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/mdx_search_tool/mdx_search_tool.py 
b/src/crewai_tools/tools/mdx_search_tool/mdx_search_tool.py index 807da62fe..3390b8dba 100644 --- a/src/crewai_tools/tools/mdx_search_tool/mdx_search_tool.py +++ b/src/crewai_tools/tools/mdx_search_tool/mdx_search_tool.py @@ -2,13 +2,9 @@ from typing import Optional, Type from pydantic import BaseModel, Field -try: - from embedchain.models.data_type import DataType - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class FixedMDXSearchToolSchema(BaseModel): @@ -42,15 +38,15 @@ class MDXSearchTool(RagTool): self._generate_description() def add(self, mdx: str) -> None: - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().add(mdx, data_type=DataType.MDX) def _run( self, search_query: str, mdx: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if mdx is not None: self.add(mdx) - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py index 8c2c5ef5d..c97585b4e 100644 --- a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py +++ b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py @@ -1,14 +1,9 @@ from typing import Any, Type -try: - from embedchain.loaders.mysql import MySQLLoader - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - from pydantic import BaseModel, Field from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class MySQLSearchToolSchema(BaseModel): @@ -27,12 +22,8 @@ class MySQLSearchTool(RagTool): db_uri: str = Field(..., description="Mandatory database URI") def __init__(self, table_name: str, **kwargs): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().__init__(**kwargs) - kwargs["data_type"] = "mysql" - kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri)) - self.add(table_name) + self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri}) self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content." 
self._generate_description() @@ -46,6 +37,8 @@ class MySQLSearchTool(RagTool): def _run( self, search_query: str, + similarity_threshold: float | None = None, + limit: int | None = None, **kwargs: Any, ) -> Any: - return super()._run(query=search_query) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py index 96f141c17..9ab1f29ea 100644 --- a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py +++ b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py @@ -2,13 +2,8 @@ from typing import Optional, Type from pydantic import BaseModel, Field -try: - from embedchain.models.data_type import DataType - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class FixedPDFSearchToolSchema(BaseModel): @@ -41,15 +36,15 @@ class PDFSearchTool(RagTool): self._generate_description() def add(self, pdf: str) -> None: - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().add(pdf, data_type=DataType.PDF_FILE) def _run( self, query: str, pdf: Optional[str] = None, + similarity_threshold: float | None = None, + limit: int | None = None, ) -> str: if pdf is not None: self.add(pdf) - return super()._run(query=query) + return super()._run(query=query, similarity_threshold=similarity_threshold, limit=limit) diff --git a/src/crewai_tools/tools/pg_search_tool/pg_search_tool.py b/src/crewai_tools/tools/pg_search_tool/pg_search_tool.py index 30e294944..31f2e697c 100644 --- a/src/crewai_tools/tools/pg_search_tool/pg_search_tool.py +++ b/src/crewai_tools/tools/pg_search_tool/pg_search_tool.py @@ -1,14 +1,9 @@ from typing import Any, Type -try: - from embedchain.loaders.postgres import PostgresLoader - EMBEDCHAIN_AVAILABLE = True -except ImportError: - EMBEDCHAIN_AVAILABLE = False - from pydantic import BaseModel, Field from ..rag.rag_tool import RagTool +from crewai_tools.rag.data_types import DataType class PGSearchToolSchema(BaseModel): @@ -27,12 +22,8 @@ class PGSearchTool(RagTool): db_uri: str = Field(..., description="Mandatory database URI") def __init__(self, table_name: str, **kwargs): - if not EMBEDCHAIN_AVAILABLE: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") super().__init__(**kwargs) - kwargs["data_type"] = "postgres" - kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri)) - self.add(table_name) + self.add(table_name, data_type=DataType.POSTGRES, metadata={"db_uri": self.db_uri}) self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content." 
self._generate_description() @@ -46,6 +37,8 @@ class PGSearchTool(RagTool): def _run( self, search_query: str, + similarity_threshold: float | None = None, + limit: int | None = None, **kwargs: Any, ) -> Any: - return super()._run(query=search_query, **kwargs) + return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit, **kwargs) diff --git a/src/crewai_tools/tools/rag/rag_tool.py b/src/crewai_tools/tools/rag/rag_tool.py index 1a9fad8b8..2397eac6f 100644 --- a/src/crewai_tools/tools/rag/rag_tool.py +++ b/src/crewai_tools/tools/rag/rag_tool.py @@ -1,17 +1,22 @@ -import portalocker - +import os from abc import ABC, abstractmethod -from typing import Any -from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import Any, cast +from crewai.rag.embeddings.factory import get_embedding_function from crewai.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field, model_validator class Adapter(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod - def query(self, question: str) -> str: + def query( + self, + question: str, + similarity_threshold: float | None = None, + limit: int | None = None, + ) -> str: """Query the knowledge base with a question and return the answer.""" @abstractmethod @@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC): class RagTool(BaseTool): class _AdapterPlaceholder(Adapter): - def query(self, question: str) -> str: + def query( + self, + question: str, + similarity_threshold: float | None = None, + limit: int | None = None, + ) -> str: raise NotImplementedError def add(self, *args: Any, **kwargs: Any) -> None: @@ -34,28 +44,149 @@ class RagTool(BaseTool): name: str = "Knowledge base" description: str = "A knowledge base that can be used to answer questions." summarize: bool = False + similarity_threshold: float = 0.6 + limit: int = 5 adapter: Adapter = Field(default_factory=_AdapterPlaceholder) - config: dict[str, Any] | None = None + config: Any | None = None @model_validator(mode="after") def _set_default_adapter(self): if isinstance(self.adapter, RagTool._AdapterPlaceholder): - try: - from embedchain import App - except ImportError: - raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`") + from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter - from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter + parsed_config = self._parse_config(self.config) - with portalocker.Lock("crewai-rag-tool.lock", timeout=10): - app = App.from_config(config=self.config) if self.config else App() - - self.adapter = EmbedchainAdapter( - embedchain_app=app, summarize=self.summarize + self.adapter = CrewAIRagAdapter( + collection_name="rag_tool_collection", + summarize=self.summarize, + similarity_threshold=self.similarity_threshold, + limit=self.limit, + config=parsed_config, ) return self + def _parse_config(self, config: Any) -> Any: + """Parse complex config format to extract provider-specific config. + + Raises: + ValueError: If the config format is invalid or uses unsupported providers. 
diff --git a/src/crewai_tools/tools/rag/rag_tool.py b/src/crewai_tools/tools/rag/rag_tool.py
index 1a9fad8b8..2397eac6f 100644
--- a/src/crewai_tools/tools/rag/rag_tool.py
+++ b/src/crewai_tools/tools/rag/rag_tool.py
@@ -1,17 +1,22 @@
-import portalocker
-
+import os
 from abc import ABC, abstractmethod
-from typing import Any
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing import Any, cast
 
+from crewai.rag.embeddings.factory import get_embedding_function
 from crewai.tools import BaseTool
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 
 class Adapter(BaseModel, ABC):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     @abstractmethod
-    def query(self, question: str) -> str:
+    def query(
+        self,
+        question: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
+    ) -> str:
         """Query the knowledge base with a question and return the answer."""
 
     @abstractmethod
@@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC):
 
 class RagTool(BaseTool):
     class _AdapterPlaceholder(Adapter):
-        def query(self, question: str) -> str:
+        def query(
+            self,
+            question: str,
+            similarity_threshold: float | None = None,
+            limit: int | None = None,
+        ) -> str:
             raise NotImplementedError
 
         def add(self, *args: Any, **kwargs: Any) -> None:
@@ -34,28 +44,149 @@ class RagTool(BaseTool):
     name: str = "Knowledge base"
     description: str = "A knowledge base that can be used to answer questions."
     summarize: bool = False
+    similarity_threshold: float = 0.6
+    limit: int = 5
     adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
-    config: dict[str, Any] | None = None
+    config: Any | None = None
 
     @model_validator(mode="after")
     def _set_default_adapter(self):
         if isinstance(self.adapter, RagTool._AdapterPlaceholder):
-            try:
-                from embedchain import App
-            except ImportError:
-                raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
+            from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
 
-            from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
+            parsed_config = self._parse_config(self.config)
 
-            with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
-                app = App.from_config(config=self.config) if self.config else App()
-
-            self.adapter = EmbedchainAdapter(
-                embedchain_app=app, summarize=self.summarize
+            self.adapter = CrewAIRagAdapter(
+                collection_name="rag_tool_collection",
+                summarize=self.summarize,
+                similarity_threshold=self.similarity_threshold,
+                limit=self.limit,
+                config=parsed_config,
             )
 
         return self
 
+    def _parse_config(self, config: Any) -> Any:
+        """Parse a user-supplied config and extract a provider-specific config object.
+
+        Raises:
+            ValueError: If the config format is invalid or uses unsupported providers.
+        """
+        if config is None:
+            return None
+
+        if isinstance(config, dict) and "provider" in config:
+            return config
+
+        if isinstance(config, dict):
+            if "vectordb" in config:
+                vectordb_config = config["vectordb"]
+                if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
+                    provider = vectordb_config["provider"]
+                    provider_config = vectordb_config.get("config", {})
+
+                    supported_providers = ["chromadb", "qdrant"]
+                    if provider not in supported_providers:
+                        raise ValueError(
+                            f"Unsupported vector database provider: '{provider}'. "
+                            f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
+                        )
+
+                    embedding_config = config.get("embedding_model")
+                    embedding_function = None
+                    if embedding_config and isinstance(embedding_config, dict):
+                        embedding_function = self._create_embedding_function(
+                            embedding_config, provider
+                        )
+
+                    return self._create_provider_config(
+                        provider, provider_config, embedding_function
+                    )
+                else:
+                    return None
+            else:
+                embedding_config = config.get("embedding_model")
+                embedding_function = None
+                if embedding_config and isinstance(embedding_config, dict):
+                    embedding_function = self._create_embedding_function(
+                        embedding_config, "chromadb"
+                    )
+
+                return self._create_provider_config("chromadb", {}, embedding_function)
+        return config
+
+    @staticmethod
+    def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
+        """Create an embedding function for the specified vector database provider."""
+        embedding_provider = embedding_config.get("provider")
+        embedding_model_config = embedding_config.get("config", {}).copy()
+
+        if "model" in embedding_model_config:
+            embedding_model_config["model_name"] = embedding_model_config.pop("model")
+
+        factory_config = {"provider": embedding_provider, **embedding_model_config}
+
+        if embedding_provider == "openai" and "api_key" not in factory_config:
+            api_key = os.getenv("OPENAI_API_KEY")
+            if api_key:
+                factory_config["api_key"] = api_key
+
+        if provider == "chromadb":
+            return get_embedding_function(factory_config)
+
+        elif provider == "qdrant":
+            chromadb_func = get_embedding_function(factory_config)
+
+            def qdrant_embed_fn(text: str) -> list[float]:
+                """Embed text with the ChromaDB function and return a list of floats for Qdrant.
+
+                Args:
+                    text: The input text to embed.
+
+                Returns:
+                    A list of floats representing the embedding.
+                """
+                embeddings = chromadb_func([text])
+                return embeddings[0] if embeddings and len(embeddings) > 0 else []
+
+            return cast(Any, qdrant_embed_fn)
+
+        return None
+
+    @staticmethod
+    def _create_provider_config(
+        provider: str, provider_config: dict, embedding_function: Any
+    ) -> Any:
+        """Create the proper provider config object."""
+        if provider == "chromadb":
+            from crewai.rag.chromadb.config import ChromaDBConfig
+
+            config_kwargs = {}
+            if embedding_function:
+                config_kwargs["embedding_function"] = embedding_function
+
+            config_kwargs.update(provider_config)
+
+            return ChromaDBConfig(**config_kwargs)
+
+        elif provider == "qdrant":
+            from crewai.rag.qdrant.config import QdrantConfig
+
+            config_kwargs = {}
+            if embedding_function:
+                config_kwargs["embedding_function"] = embedding_function
+
+            config_kwargs.update(provider_config)
+
+            return QdrantConfig(**config_kwargs)
+
+        return None
+
     def add(
         self,
         *args: Any,
@@ -66,5 +197,13 @@ class RagTool(BaseTool):
     def _run(
         self,
         query: str,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
    ) -> str:
-        return f"Relevant Content:\n{self.adapter.query(query)}"
+        threshold = (
+            similarity_threshold
+            if similarity_threshold is not None
+            else self.similarity_threshold
+        )
+        result_limit = limit if limit is not None else self.limit
+        return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"
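Concretely, `_parse_config` accepts either a bare provider dict or the nested form below. This sketch (provider and model values are illustrative) resolves to a `ChromaDBConfig` whose embedding function is built by `get_embedding_function`, with `model` renamed to `model_name` along the way; if no `api_key` is supplied, the OpenAI key is read from the `OPENAI_API_KEY` environment variable:

    from crewai_tools.tools.rag.rag_tool import RagTool


    class MyTool(RagTool):
        pass


    tool = MyTool(
        config={
            "vectordb": {
                "provider": "chromadb",  # "qdrant" is the other supported provider
                "config": {},            # forwarded to ChromaDBConfig as keyword args
            },
            "embedding_model": {
                "provider": "openai",
                "config": {"model": "text-embedding-3-small"},  # becomes model_name
            },
        }
    )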
diff --git a/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py b/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py
index 93d696ab1..2ccfa4eb2 100644
--- a/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py
+++ b/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py
@@ -39,7 +39,9 @@ class TXTSearchTool(RagTool):
         self,
         search_query: str,
         txt: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if txt is not None:
             self.add(txt)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
diff --git a/src/crewai_tools/tools/website_search/website_search_tool.py b/src/crewai_tools/tools/website_search/website_search_tool.py
index 9728b44db..ac8084d3f 100644
--- a/src/crewai_tools/tools/website_search/website_search_tool.py
+++ b/src/crewai_tools/tools/website_search/website_search_tool.py
@@ -1,14 +1,9 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedWebsiteSearchToolSchema(BaseModel):
@@ -44,15 +39,15 @@ class WebsiteSearchTool(RagTool):
         self._generate_description()
 
     def add(self, website: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
-        super().add(website, data_type=DataType.WEB_PAGE)
+        super().add(website, data_type=DataType.WEBSITE)
 
     def _run(
         self,
         search_query: str,
         website: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if website is not None:
             self.add(website)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
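Note the data-type rename from `WEB_PAGE` to `WEBSITE`. End to end, the updated tool is driven like this sketch (the URL is a placeholder):

    from crewai_tools.tools import WebsiteSearchTool

    tool = WebsiteSearchTool(website="https://docs.crewai.com")  # indexed as DataType.WEBSITE
    result = tool._run(
        search_query="How do I define an agent?",
        similarity_threshold=0.7,
        limit=4,
    )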
diff --git a/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py b/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py
index 426b0ca38..8509c2d42 100644
--- a/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py
+++ b/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py
@@ -39,7 +39,9 @@ class XMLSearchTool(RagTool):
         self,
         search_query: str,
         xml: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if xml is not None:
             self.add(xml)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
diff --git a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py
index 6d16a708d..80f597665 100644
--- a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py
+++ b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py
@@ -1,14 +1,9 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
-
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedYoutubeChannelSearchToolSchema(BaseModel):
@@ -55,7 +50,9 @@ class YoutubeChannelSearchTool(RagTool):
         self,
         search_query: str,
         youtube_channel_handle: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if youtube_channel_handle is not None:
             self.add(youtube_channel_handle)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
diff --git a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py
index b93cc6c29..000c81cec 100644
--- a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py
+++ b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py
@@ -1,14 +1,10 @@
 from typing import Any, Optional, Type
 
-try:
-    from embedchain.models.data_type import DataType
-    EMBEDCHAIN_AVAILABLE = True
-except ImportError:
-    EMBEDCHAIN_AVAILABLE = False
 
 from pydantic import BaseModel, Field
 
 from ..rag.rag_tool import RagTool
+from crewai_tools.rag.data_types import DataType
 
 
 class FixedYoutubeVideoSearchToolSchema(BaseModel):
@@ -44,15 +40,15 @@ class YoutubeVideoSearchTool(RagTool):
         self._generate_description()
 
     def add(self, youtube_video_url: str) -> None:
-        if not EMBEDCHAIN_AVAILABLE:
-            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
         super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
 
     def _run(
         self,
         search_query: str,
         youtube_video_url: Optional[str] = None,
+        similarity_threshold: float | None = None,
+        limit: int | None = None,
     ) -> str:
         if youtube_video_url is not None:
             self.add(youtube_video_url)
-        return super()._run(query=search_query)
+        return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
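Both YouTube tools follow the same lazy-add pattern, so content passed at query time is indexed on the fly. A minimal sketch (the video URL is a placeholder):

    from crewai_tools.tools import YoutubeVideoSearchTool

    tool = YoutubeVideoSearchTool()
    result = tool._run(
        search_query="What is said about agent memory?",
        youtube_video_url="https://www.youtube.com/watch?v=abc123xyz",  # placeholder URL
        limit=3,
    )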
diff --git a/tests/tools/rag/rag_tool_test.py b/tests/tools/rag/rag_tool_test.py
index 42baccc2c..693cd120a 100644
--- a/tests/tools/rag/rag_tool_test.py
+++ b/tests/tools/rag/rag_tool_test.py
@@ -1,43 +1,54 @@
-import os
-from tempfile import NamedTemporaryFile
+from tempfile import TemporaryDirectory
 from typing import cast
-from unittest import mock
+from pathlib import Path
 
-from pytest import fixture
-
-from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
+from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
 from crewai_tools.tools.rag.rag_tool import RagTool
 
 
-@fixture(autouse=True)
-def mock_embedchain_db_uri():
-    with NamedTemporaryFile() as tmp:
-        uri = f"sqlite:///{tmp.name}"
-        with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
-            yield
-
-
-def test_custom_llm_and_embedder():
+def test_rag_tool_initialization():
+    """Test that RagTool initializes with the CrewAI adapter by default."""
     class MyTool(RagTool):
         pass
 
-    tool = MyTool(
-        config=dict(
-            llm=dict(
-                provider="openai",
-                config=dict(model="gpt-3.5-custom"),
-            ),
-            embedder=dict(
-                provider="openai",
-                config=dict(model="text-embedding-3-custom"),
-            ),
-        )
-    )
+    tool = MyTool()
 
     assert tool.adapter is not None
-    assert isinstance(tool.adapter, EmbedchainAdapter)
+    assert isinstance(tool.adapter, CrewAIRagAdapter)
+
+    adapter = cast(CrewAIRagAdapter, tool.adapter)
+    assert adapter.collection_name == "rag_tool_collection"
+    assert adapter._client is not None
 
-    adapter = cast(EmbedchainAdapter, tool.adapter)
-    assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
-    assert (
-        adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
-    )
+
+def test_rag_tool_add_and_query():
+    """Test adding content and querying with RagTool."""
+    class MyTool(RagTool):
+        pass
+
+    tool = MyTool()
+
+    tool.add("The sky is blue on a clear day.")
+    tool.add("Machine learning is a subset of artificial intelligence.")
+
+    result = tool._run(query="What color is the sky?")
+    assert "Relevant Content:" in result
+
+    result = tool._run(query="Tell me about machine learning")
+    assert "Relevant Content:" in result
+
+
+def test_rag_tool_with_file():
+    """Test RagTool with file content."""
+    with TemporaryDirectory() as tmpdir:
+        test_file = Path(tmpdir) / "test.txt"
+        test_file.write_text("Python is a programming language known for its simplicity.")
+
+        class MyTool(RagTool):
+            pass
+
+        tool = MyTool()
+        tool.add(str(test_file))
+
+        result = tool._run(query="What is Python?")
+        assert "Relevant Content:" in result
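The test_search_tools suite below relies on a `mock_adapter` fixture that is defined elsewhere in that file and not touched by this patch. Assuming it wraps a `MagicMock` constrained to the `Adapter` interface with a canned query response, it plausibly looks roughly like this sketch:

    import pytest
    from unittest.mock import MagicMock

    from crewai_tools.tools.rag.rag_tool import Adapter


    @pytest.fixture
    def mock_adapter():
        # Mock conforming to the Adapter interface; the canned response matches
        # the substring most of the assertions below look for.
        adapter = MagicMock(spec=Adapter)
        adapter.query.return_value = "this is a test"
        return adapter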
diff --git a/tests/tools/test_search_tools.py b/tests/tools/test_search_tools.py
index eaa0c591c..b912ef005 100644
--- a/tests/tools/test_search_tools.py
+++ b/tests/tools/test_search_tools.py
@@ -1,11 +1,11 @@
 import os
 import tempfile
 from pathlib import Path
-from unittest.mock import ANY, MagicMock
+from unittest.mock import MagicMock
 
 import pytest
 
-from embedchain.models.data_type import DataType
+from crewai_tools.rag.data_types import DataType
 from crewai_tools.tools import (
     CodeDocsSearchTool,
     CSVSearchTool,
@@ -49,7 +49,7 @@ def test_pdf_search_tool(mock_adapter):
     result = tool._run(query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -58,7 +58,7 @@ def test_pdf_search_tool(mock_adapter):
     result = tool._run(pdf="test.pdf", query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
 
 def test_txt_search_tool():
@@ -82,7 +82,7 @@ def test_docx_search_tool(mock_adapter):
     result = tool._run(search_query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -91,7 +91,7 @@ def test_docx_search_tool(mock_adapter):
     result = tool._run(docx="test.docx", search_query="test content")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
-    mock_adapter.query.assert_called_once_with("test content")
+    mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
 
 
 def test_json_search_tool():
@@ -114,7 +114,7 @@ def test_xml_search_tool(mock_adapter):
     result = tool._run(search_query="test XML", xml="test.xml")
     assert "this is a test" in result.lower()
     mock_adapter.add.assert_called_once_with("test.xml")
-    mock_adapter.query.assert_called_once_with("test XML")
+    mock_adapter.query.assert_called_once_with("test XML", similarity_threshold=0.6, limit=5)
 
 
 def test_csv_search_tool():
@@ -153,8 +153,8 @@ def test_website_search_tool(mock_adapter):
     tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
     result = tool._run(search_query=search_query)
 
-    mock_adapter.query.assert_called_once_with("what is crewai?")
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
 
     assert "this is a test" in result.lower()
 
@@ -164,8 +164,8 @@ def test_website_search_tool(mock_adapter):
     tool = WebsiteSearchTool(adapter=mock_adapter)
     result = tool._run(website=website, search_query=search_query)
 
-    mock_adapter.query.assert_called_once_with("what is crewai?")
-    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
+    mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
+    mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
 
     assert "this is a test" in result.lower()
 
@@ -185,7 +185,7 @@ def test_youtube_video_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -197,7 +197,7 @@ def test_youtube_video_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_youtube_channel_search_tool(mock_adapter):
@@ -213,7 +213,7 @@ def test_youtube_channel_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -227,7 +227,7 @@ def test_youtube_channel_search_tool(mock_adapter):
     mock_adapter.add.assert_called_once_with(
         youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
     )
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_code_docs_search_tool(mock_adapter):
@@ -239,7 +239,7 @@ def test_code_docs_search_tool(mock_adapter):
     result = tool._run(search_query=search_query)
     assert "test documentation" in result
     mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
     mock_adapter.query.reset_mock()
     mock_adapter.add.reset_mock()
@@ -248,7 +248,7 @@ def test_code_docs_search_tool(mock_adapter):
     result = tool._run(docs_url=docs_url, search_query=search_query)
     assert "test documentation" in result
     mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
-    mock_adapter.query.assert_called_once_with(search_query)
+    mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
 
 
 def test_github_search_tool(mock_adapter):
@@ -264,9 +264,11 @@ def test_github_search_tool(mock_adapter):
     result = tool._run(search_query="tell me about crewai repo")
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code"], "gh_token": "test_token"}
     )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure content types provided by the run call are used
     mock_adapter.query.reset_mock()
@@ -280,9 +282,11 @@ def test_github_search_tool(mock_adapter):
     )
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "issue"], "gh_token": "test_token"}
     )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure default content types are used if not provided
     mock_adapter.query.reset_mock()
@@ -295,9 +299,11 @@ def test_github_search_tool(mock_adapter):
     )
     assert "repo description" in result
     mock_adapter.add.assert_called_once_with(
-        "repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
+        "https://github.com/crewai/crewai",
+        data_type=DataType.GITHUB,
+        metadata={"content_types": ["code", "repo", "pr", "issue"], "gh_token": "test_token"}
     )
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
 
     # ensure nothing is added if no repo is provided
     mock_adapter.query.reset_mock()
@@ -306,4 +312,4 @@ def test_github_search_tool(mock_adapter):
     tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
     result = tool._run(search_query="tell me about crewai repo")
     mock_adapter.add.assert_not_called()
-    mock_adapter.query.assert_called_once_with("tell me about crewai repo")
+    mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
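For completeness, the call shape these assertions exercise looks like the following sketch; the `github_repo` constructor argument is inferred from the asserted add call rather than shown in this hunk, and the repo URL and token are placeholders:

    from crewai_tools.tools import GithubSearchTool

    tool = GithubSearchTool(
        github_repo="https://github.com/crewai/crewai",  # inferred parameter name
        gh_token="<personal-access-token>",              # placeholder token
        content_types=["code", "issue"],                 # defaults to code/repo/pr/issue
    )
    result = tool._run(search_query="tell me about crewai repo")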