fix: Remove kwargs from all RagTools (#285)

* fix: remove kwargs from all (except mysql & pg) RagTools

The agent uses the tool description to decide what to propagate when a tool with **kwargs is found, but this often leads to failures during the tool invocation step.

This happens because the final description ends up like this:

```
CrewStructuredTool(name='Knowledge base', description='Tool Name: Knowledge base
Tool Arguments: {'query': {'description': None, 'type': 'str'}, 'kwargs': {'description': None, 'type': 'Any'}}
Tool Description: A knowledge base that can be used to answer questions.')
```

The agent then tries to infer and pass a kwargs parameter, which isn’t supported by the schema at all.

* feat: adding test to search tools

* feat: add db (chromadb folder) to .gitignore

* fix: fix github search integration

A few attributes were missing when calling the .add method: data_type and loader.

Also, update the query search according to the EmbedChain documentation, the query must include the type and repo keys

* fix: rollback YoutubeChannel paramenter

* chore: fix type hinting for CodeDocs search

* fix: ensure proper configuration when call `add`

According to the documentation, some search methods must be defined as either a loader or a data_type. This commit ensures that.

* build: add optional-dependencies for github and xml search

* test: mocking external requests from search_tool tests

* build: add pytest-recording as devDependencie
This commit is contained in:
Lucas Gomide
2025-05-05 15:15:50 -03:00
committed by GitHub
parent 93d043bcd4
commit fd4ef4f47a
23 changed files with 2051 additions and 279 deletions

View File

@@ -31,30 +31,19 @@ class CodeDocsSearchTool(RagTool):
def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docs_url is not None:
kwargs["data_type"] = DataType.DOCS_SITE
self.add(docs_url)
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docs_url" in kwargs:
self.add(kwargs["docs_url"])
def add(self, docs_url: str) -> None:
super().add(docs_url, data_type=DataType.DOCS_SITE)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
docs_url: Optional[str] = None,
) -> str:
if docs_url is not None:
self.add(docs_url)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class CSVSearchTool(RagTool):
def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if csv is not None:
kwargs["data_type"] = DataType.CSV
self.add(csv)
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "csv" in kwargs:
self.add(kwargs["csv"])
def add(self, csv: str) -> None:
super().add(csv, data_type=DataType.CSV)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
csv: Optional[str] = None,
) -> str:
if csv is not None:
self.add(csv)
return super()._run(query=search_query)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Optional, Type
from embedchain.loaders.directory_loader import DirectoryLoader
from pydantic import BaseModel, Field
@@ -31,30 +31,22 @@ class DirectorySearchTool(RagTool):
def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
self.add(directory)
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "directory" in kwargs:
self.add(kwargs["directory"])
def add(self, directory: str) -> None:
super().add(
directory,
loader=DirectoryLoader(config=dict(recursive=True)),
)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
directory: Optional[str] = None,
) -> str:
if directory is not None:
self.add(directory)
return super()._run(query=search_query)

View File

@@ -37,36 +37,19 @@ class DOCXSearchTool(RagTool):
def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
kwargs["data_type"] = DataType.DOCX
self.add(docx)
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docx" in kwargs:
self.add(kwargs["docx"])
def add(self, docx: str) -> None:
super().add(docx, data_type=DataType.DOCX)
def _run(
self,
**kwargs: Any,
search_query: str,
docx: Optional[str] = None,
) -> Any:
search_query = kwargs.get("search_query")
if search_query is None:
search_query = kwargs.get("query")
docx = kwargs.get("docx")
if docx is not None:
self.add(docx)
return super()._run(query=search_query, **kwargs)
return super()._run(query=search_query)

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Optional, Type
from typing import List, Optional, Type
from embedchain.loaders.github import GithubLoader
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from ..rag.rag_tool import RagTool
@@ -27,19 +27,29 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
class GithubSearchTool(RagTool):
name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
description: str = (
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
)
summarize: bool = False
gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema
content_types: List[str]
content_types: List[str] = Field(
default_factory=lambda: ["code", "repo", "pr", "issue"],
description="Content types you want to be included search, options: [code, repo, pr, issue]",
)
_loader: GithubLoader | None = PrivateAttr(default=None)
def __init__(self, github_repo: Optional[str] = None, **kwargs):
def __init__(
self,
github_repo: Optional[str] = None,
content_types: Optional[List[str]] = None,
**kwargs,
):
super().__init__(**kwargs)
if github_repo is not None:
kwargs["data_type"] = "github"
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
self._loader = GithubLoader(config={"token": self.gh_token})
self.add(repo=github_repo)
if github_repo and content_types:
self.add(repo=github_repo, content_types=content_types)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
self.args_schema = FixedGithubSearchToolSchema
self._generate_description()
@@ -47,26 +57,25 @@ class GithubSearchTool(RagTool):
def add(
self,
repo: str,
content_types: List[str] | None = None,
**kwargs: Any,
content_types: Optional[List[str]] = None,
) -> None:
content_types = content_types or self.content_types
super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "github_repo" in kwargs:
self.add(
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
)
super().add(
f"repo:{repo} type:{','.join(content_types)}",
data_type="github",
loader=self._loader,
)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
github_repo: Optional[str] = None,
content_types: Optional[List[str]] = None,
) -> str:
if github_repo:
self.add(
repo=github_repo,
content_types=content_types,
)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class JSONSearchTool(RagTool):
def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if json_path is not None:
kwargs["data_type"] = DataType.JSON
self.add(json_path)
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
self.args_schema = FixedJSONSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "json_path" in kwargs:
self.add(kwargs["json_path"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
json_path: Optional[str] = None,
) -> str:
if json_path is not None:
self.add(json_path)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class MDXSearchTool(RagTool):
def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if mdx is not None:
kwargs["data_type"] = DataType.MDX
self.add(mdx)
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "mdx" in kwargs:
self.add(kwargs["mdx"])
def add(self, mdx: str) -> None:
super().add(mdx, data_type=DataType.MDX)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
mdx: Optional[str] = None,
) -> str:
if mdx is not None:
self.add(mdx)
return super()._run(query=search_query)

View File

@@ -30,39 +30,19 @@ class PDFSearchTool(RagTool):
def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
kwargs["data_type"] = DataType.PDF_FILE
self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema
self._generate_description()
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
def add(self, pdf: str) -> None:
super().add(pdf, data_type=DataType.PDF_FILE)
from crewai_tools.adapters.pdf_embedchain_adapter import (
PDFEmbedchainAdapter,
)
app = App.from_config(config=self.config) if self.config else App()
self.adapter = PDFEmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "pdf" in kwargs:
self.add(kwargs["pdf"])
pdf: Optional[str] = None,
) -> str:
if pdf is not None:
self.add(pdf)
return super()._run(query=query)

View File

@@ -32,7 +32,9 @@ class PDFTextWritingTool(RagTool):
"""A tool to add text to specific positions in a PDF, with custom font support."""
name: str = "PDF Text Writing Tool"
description: str = "A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
description: str = (
"A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
)
args_schema: Type[BaseModel] = PDFTextWritingToolSchema
def run(
@@ -45,7 +47,6 @@ class PDFTextWritingTool(RagTool):
font_name: str = "F1",
font_file: Optional[str] = None,
page_number: int = 0,
**kwargs,
) -> str:
reader = PdfReader(pdf_path)
writer = PdfWriter()

View File

@@ -59,11 +59,5 @@ class RagTool(BaseTool):
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
self._before_run(query, **kwargs)
) -> str:
return f"Relevant Content:\n{self.adapter.query(query)}"
def _before_run(self, query, **kwargs):
pass

View File

@@ -41,14 +41,15 @@ class SerplyJobSearchTool(RagTool):
def _run(
self,
**kwargs: Any,
) -> Any:
query: Optional[str] = None,
search_query: Optional[str] = None,
) -> str:
query_payload = {}
if "query" in kwargs:
query_payload["q"] = kwargs["query"]
elif "search_query" in kwargs:
query_payload["q"] = kwargs["search_query"]
if query is not None:
query_payload["q"] = query
elif search_query is not None:
query_payload["q"] = search_query
# build the url
url = f"{self.request_url}{urlencode(query_payload)}"

View File

@@ -18,7 +18,9 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
class SerplyWebpageToMarkdownTool(RagTool):
name: str = "Webpage to Markdown"
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
description: str = (
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
)
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
request_url: str = "https://api.serply.io/v1/request"
proxy_location: Optional[str] = "US"
@@ -39,9 +41,9 @@ class SerplyWebpageToMarkdownTool(RagTool):
def _run(
self,
**kwargs: Any,
) -> Any:
data = {"url": kwargs["url"], "method": "GET", "response_type": "markdown"}
url: str,
) -> str:
data = {"url": url, "method": "GET", "response_type": "markdown"}
response = requests.request(
"POST", self.request_url, headers=self.headers, json=data
)

View File

@@ -1,6 +1,5 @@
from typing import Any, Optional, Type
from typing import Optional, Type
from embedchain.models.data_type import DataType
from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool
@@ -31,30 +30,16 @@ class TXTSearchTool(RagTool):
def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
kwargs["data_type"] = DataType.TEXT_FILE
self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "txt" in kwargs:
self.add(kwargs["txt"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
txt: Optional[str] = None,
) -> str:
if txt is not None:
self.add(txt)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content."
description: str = (
"A tool that can be used to semantic search a query from a specific URL content."
)
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if website is not None:
kwargs["data_type"] = DataType.WEB_PAGE
self.add(website)
self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "website" in kwargs:
self.add(kwargs["website"])
def add(self, website: str) -> None:
super().add(website, data_type=DataType.WEB_PAGE)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
website: Optional[str] = None,
) -> str:
if website is not None:
self.add(website)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class XMLSearchTool(RagTool):
def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if xml is not None:
kwargs["data_type"] = DataType.XML
self.add(xml)
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "xml" in kwargs:
self.add(kwargs["xml"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
xml: Optional[str] = None,
) -> str:
if xml is not None:
self.add(xml)
return super()._run(query=search_query)

View File

@@ -25,13 +25,14 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
description: str = (
"A tool that can be used to semantic search a query from a Youtube Channels content."
)
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_channel_handle is not None:
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
self.add(youtube_channel_handle)
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema
@@ -40,23 +41,16 @@ class YoutubeChannelSearchTool(RagTool):
def add(
self,
youtube_channel_handle: str,
**kwargs: Any,
) -> None:
if not youtube_channel_handle.startswith("@"):
youtube_channel_handle = f"@{youtube_channel_handle}"
super().add(youtube_channel_handle, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"])
super().add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
youtube_channel_handle: Optional[str] = None,
) -> str:
if youtube_channel_handle is not None:
self.add(youtube_channel_handle)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
description: str = (
"A tool that can be used to semantic search a query from a Youtube Video content."
)
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_video_url is not None:
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
self.add(youtube_video_url)
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"])
def add(self, youtube_video_url: str) -> None:
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
youtube_video_url: Optional[str] = None,
) -> str:
if youtube_video_url is not None:
self.add(youtube_video_url)
return super()._run(query=search_query)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,309 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import ANY, MagicMock
import pytest
from embedchain.models.data_type import DataType
from crewai_tools.tools import (
CodeDocsSearchTool,
CSVSearchTool,
DirectorySearchTool,
DOCXSearchTool,
GithubSearchTool,
JSONSearchTool,
MDXSearchTool,
PDFSearchTool,
TXTSearchTool,
WebsiteSearchTool,
XMLSearchTool,
YoutubeChannelSearchTool,
YoutubeVideoSearchTool,
)
from crewai_tools.tools.rag.rag_tool import Adapter
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
@pytest.fixture
def mock_adapter():
mock_adapter = MagicMock(spec=Adapter)
return mock_adapter
def test_directory_search_tool():
with tempfile.TemporaryDirectory() as temp_dir:
test_file = Path(temp_dir) / "test.txt"
test_file.write_text("This is a test file for directory search")
tool = DirectorySearchTool(directory=temp_dir)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
def test_pdf_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = PDFSearchTool(pdf="test.pdf", adapter=mock_adapter)
result = tool._run(query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with("test content")
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = PDFSearchTool(adapter=mock_adapter)
result = tool._run(pdf="test.pdf", query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with("test content")
def test_txt_search_tool():
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
temp_file.write(b"This is a test file for txt search")
temp_file_path = temp_file.name
try:
tool = TXTSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
finally:
os.unlink(temp_file_path)
def test_docx_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = DOCXSearchTool(docx="test.docx", adapter=mock_adapter)
result = tool._run(search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with("test content")
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = DOCXSearchTool(adapter=mock_adapter)
result = tool._run(docx="test.docx", search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with("test content")
def test_json_search_tool():
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file.write(b'{"test": "This is a test JSON file"}')
temp_file_path = temp_file.name
try:
tool = JSONSearchTool()
result = tool._run(search_query="test JSON", json_path=temp_file_path)
assert "test json" in result.lower()
finally:
os.unlink(temp_file_path)
def test_xml_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = XMLSearchTool(adapter=mock_adapter)
result = tool._run(search_query="test XML", xml="test.xml")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.xml")
mock_adapter.query.assert_called_once_with("test XML")
def test_csv_search_tool():
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
temp_file.write(b"name,description\ntest,This is a test CSV file")
temp_file_path = temp_file.name
try:
tool = CSVSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test CSV")
assert "test csv" in result.lower()
finally:
os.unlink(temp_file_path)
def test_mdx_search_tool():
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
temp_file.write(b"# Test MDX\nThis is a test MDX file")
temp_file_path = temp_file.name
try:
tool = MDXSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test MDX")
assert "test mdx" in result.lower()
finally:
os.unlink(temp_file_path)
def test_website_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
website = "https://crewai.com"
search_query = "what is crewai?"
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
result = tool._run(search_query=search_query)
mock_adapter.query.assert_called_once_with("what is crewai?")
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
assert "this is a test" in result.lower()
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = WebsiteSearchTool(adapter=mock_adapter)
result = tool._run(website=website, search_query=search_query)
mock_adapter.query.assert_called_once_with("what is crewai?")
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
assert "this is a test" in result.lower()
def test_youtube_video_search_tool(mock_adapter):
mock_adapter.query.return_value = "some video description"
youtube_video_url = "https://www.youtube.com/watch?v=sample-video-id"
search_query = "what is the video about?"
tool = YoutubeVideoSearchTool(
youtube_video_url=youtube_video_url,
adapter=mock_adapter,
)
result = tool._run(search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeVideoSearchTool(adapter=mock_adapter)
result = tool._run(youtube_video_url=youtube_video_url, search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(search_query)
def test_youtube_channel_search_tool(mock_adapter):
mock_adapter.query.return_value = "channel description"
youtube_channel_handle = "@crewai"
search_query = "what is the channel about?"
tool = YoutubeChannelSearchTool(
youtube_channel_handle=youtube_channel_handle, adapter=mock_adapter
)
result = tool._run(search_query=search_query)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeChannelSearchTool(adapter=mock_adapter)
result = tool._run(
youtube_channel_handle=youtube_channel_handle, search_query=search_query
)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(search_query)
def test_code_docs_search_tool(mock_adapter):
mock_adapter.query.return_value = "test documentation"
docs_url = "https://crewai.com/any-docs-url"
search_query = "test documentation"
tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
result = tool._run(search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = CodeDocsSearchTool(adapter=mock_adapter)
result = tool._run(docs_url=docs_url, search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(search_query)
def test_github_search_tool(mock_adapter):
mock_adapter.query.return_value = "repo description"
# ensure the provided repo and content types are used after initialization
tool = GithubSearchTool(
gh_token="test_token",
github_repo="crewai/crewai",
content_types=["code"],
adapter=mock_adapter,
)
result = tool._run(search_query="tell me about crewai repo")
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure content types provided by run call is used
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
content_types=["code", "issue"],
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure default content types are used if not provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure nothing is added if no repo is provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(search_query="tell me about crewai repo")
mock_adapter.add.assert_not_called()
mock_adapter.query.assert_called_once_with("tell me about crewai repo")