fix: Remove kwargs from all RagTools (#285)

* fix: remove kwargs from all (except mysql & pg) RagTools

The agent uses the tool description to decide what to propagate when a tool with **kwargs is found, but this often leads to failures during the tool invocation step.

This happens because the final description ends up like this:

```
CrewStructuredTool(name='Knowledge base', description='Tool Name: Knowledge base
Tool Arguments: {'query': {'description': None, 'type': 'str'}, 'kwargs': {'description': None, 'type': 'Any'}}
Tool Description: A knowledge base that can be used to answer questions.')
```

The agent then tries to infer and pass a kwargs parameter, which isn’t supported by the schema at all.

* feat: adding test to search tools

* feat: add db (chromadb folder) to .gitignore

* fix: fix github search integration

A few attributes were missing when calling the .add method: data_type and loader.

Also, update the search query according to the EmbedChain documentation: the query must include the type and repo keys.

* fix: rollback YoutubeChannel parameter

* chore: fix type hinting for CodeDocs search

* fix: ensure proper configuration when calling `add`

According to the documentation, some search methods must be defined as either a loader or a data_type. This commit ensures that.

* build: add optional-dependencies for github and xml search

* test: mocking external requests from search_tool tests

* build: add pytest-recording as a dev dependency
This commit is contained in:
Lucas Gomide
2025-05-05 15:15:50 -03:00
committed by GitHub
parent 93d043bcd4
commit fd4ef4f47a
23 changed files with 2051 additions and 279 deletions

View File

@@ -31,30 +31,19 @@ class CodeDocsSearchTool(RagTool):
def __init__(self, docs_url: Optional[str] = None, **kwargs): def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docs_url is not None: if docs_url is not None:
kwargs["data_type"] = DataType.DOCS_SITE
self.add(docs_url) self.add(docs_url)
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content." self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, docs_url: str) -> None:
self, super().add(docs_url, data_type=DataType.DOCS_SITE)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docs_url" in kwargs:
self.add(kwargs["docs_url"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, docs_url: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if docs_url is not None:
self.add(docs_url)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class CSVSearchTool(RagTool):
def __init__(self, csv: Optional[str] = None, **kwargs): def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if csv is not None: if csv is not None:
kwargs["data_type"] = DataType.CSV
self.add(csv) self.add(csv)
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content." self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema self.args_schema = FixedCSVSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, csv: str) -> None:
self, super().add(csv, data_type=DataType.CSV)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "csv" in kwargs:
self.add(kwargs["csv"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, csv: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if csv is not None:
self.add(csv)
return super()._run(query=search_query)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type from typing import Optional, Type
from embedchain.loaders.directory_loader import DirectoryLoader from embedchain.loaders.directory_loader import DirectoryLoader
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -31,30 +31,22 @@ class DirectorySearchTool(RagTool):
def __init__(self, directory: Optional[str] = None, **kwargs): def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if directory is not None: if directory is not None:
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
self.add(directory) self.add(directory)
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content." self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema self.args_schema = FixedDirectorySearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, directory: str) -> None:
self, super().add(
*args: Any, directory,
**kwargs: Any, loader=DirectoryLoader(config=dict(recursive=True)),
) -> None: )
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "directory" in kwargs:
self.add(kwargs["directory"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, directory: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if directory is not None:
self.add(directory)
return super()._run(query=search_query)

View File

@@ -37,36 +37,19 @@ class DOCXSearchTool(RagTool):
def __init__(self, docx: Optional[str] = None, **kwargs): def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docx is not None: if docx is not None:
kwargs["data_type"] = DataType.DOCX
self.add(docx) self.add(docx)
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content." self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema self.args_schema = FixedDOCXSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, docx: str) -> None:
self, super().add(docx, data_type=DataType.DOCX)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docx" in kwargs:
self.add(kwargs["docx"])
def _run( def _run(
self, self,
**kwargs: Any, search_query: str,
docx: Optional[str] = None,
) -> Any: ) -> Any:
search_query = kwargs.get("search_query")
if search_query is None:
search_query = kwargs.get("query")
docx = kwargs.get("docx")
if docx is not None: if docx is not None:
self.add(docx) self.add(docx)
return super()._run(query=search_query, **kwargs) return super()._run(query=search_query)

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Optional, Type from typing import List, Optional, Type
from embedchain.loaders.github import GithubLoader from embedchain.loaders.github import GithubLoader
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, PrivateAttr
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -27,19 +27,29 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
class GithubSearchTool(RagTool): class GithubSearchTool(RagTool):
name: str = "Search a github repo's content" name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." description: str = (
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
)
summarize: bool = False summarize: bool = False
gh_token: str gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema args_schema: Type[BaseModel] = GithubSearchToolSchema
content_types: List[str] content_types: List[str] = Field(
default_factory=lambda: ["code", "repo", "pr", "issue"],
description="Content types you want to be included search, options: [code, repo, pr, issue]",
)
_loader: GithubLoader | None = PrivateAttr(default=None)
def __init__(self, github_repo: Optional[str] = None, **kwargs): def __init__(
self,
github_repo: Optional[str] = None,
content_types: Optional[List[str]] = None,
**kwargs,
):
super().__init__(**kwargs) super().__init__(**kwargs)
if github_repo is not None: self._loader = GithubLoader(config={"token": self.gh_token})
kwargs["data_type"] = "github"
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
self.add(repo=github_repo) if github_repo and content_types:
self.add(repo=github_repo, content_types=content_types)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
self.args_schema = FixedGithubSearchToolSchema self.args_schema = FixedGithubSearchToolSchema
self._generate_description() self._generate_description()
@@ -47,26 +57,25 @@ class GithubSearchTool(RagTool):
def add( def add(
self, self,
repo: str, repo: str,
content_types: List[str] | None = None, content_types: Optional[List[str]] = None,
**kwargs: Any,
) -> None: ) -> None:
content_types = content_types or self.content_types content_types = content_types or self.content_types
super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs) super().add(
f"repo:{repo} type:{','.join(content_types)}",
def _before_run( data_type="github",
self, loader=self._loader,
query: str, )
**kwargs: Any,
) -> Any:
if "github_repo" in kwargs:
self.add(
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
)
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, github_repo: Optional[str] = None,
) -> Any: content_types: Optional[List[str]] = None,
return super()._run(query=search_query, **kwargs) ) -> str:
if github_repo:
self.add(
repo=github_repo,
content_types=content_types,
)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class JSONSearchTool(RagTool):
def __init__(self, json_path: Optional[str] = None, **kwargs): def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if json_path is not None: if json_path is not None:
kwargs["data_type"] = DataType.JSON
self.add(json_path) self.add(json_path)
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content." self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
self.args_schema = FixedJSONSearchToolSchema self.args_schema = FixedJSONSearchToolSchema
self._generate_description() self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "json_path" in kwargs:
self.add(kwargs["json_path"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, json_path: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if json_path is not None:
self.add(json_path)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class MDXSearchTool(RagTool):
def __init__(self, mdx: Optional[str] = None, **kwargs): def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if mdx is not None: if mdx is not None:
kwargs["data_type"] = DataType.MDX
self.add(mdx) self.add(mdx)
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content." self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema self.args_schema = FixedMDXSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, mdx: str) -> None:
self, super().add(mdx, data_type=DataType.MDX)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "mdx" in kwargs:
self.add(kwargs["mdx"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, mdx: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if mdx is not None:
self.add(mdx)
return super()._run(query=search_query)

View File

@@ -30,39 +30,19 @@ class PDFSearchTool(RagTool):
def __init__(self, pdf: Optional[str] = None, **kwargs): def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if pdf is not None: if pdf is not None:
kwargs["data_type"] = DataType.PDF_FILE
self.add(pdf) self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content." self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema self.args_schema = FixedPDFSearchToolSchema
self._generate_description() self._generate_description()
@model_validator(mode="after") def add(self, pdf: str) -> None:
def _set_default_adapter(self): super().add(pdf, data_type=DataType.PDF_FILE)
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
from crewai_tools.adapters.pdf_embedchain_adapter import ( def _run(
PDFEmbedchainAdapter,
)
app = App.from_config(config=self.config) if self.config else App()
self.adapter = PDFEmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self, self,
query: str, query: str,
**kwargs: Any, pdf: Optional[str] = None,
) -> Any: ) -> str:
if "pdf" in kwargs: if pdf is not None:
self.add(kwargs["pdf"]) self.add(pdf)
return super()._run(query=query)

View File

@@ -32,7 +32,9 @@ class PDFTextWritingTool(RagTool):
"""A tool to add text to specific positions in a PDF, with custom font support.""" """A tool to add text to specific positions in a PDF, with custom font support."""
name: str = "PDF Text Writing Tool" name: str = "PDF Text Writing Tool"
description: str = "A tool that can write text to a specific position in a PDF document, with optional custom font embedding." description: str = (
"A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
)
args_schema: Type[BaseModel] = PDFTextWritingToolSchema args_schema: Type[BaseModel] = PDFTextWritingToolSchema
def run( def run(
@@ -45,7 +47,6 @@ class PDFTextWritingTool(RagTool):
font_name: str = "F1", font_name: str = "F1",
font_file: Optional[str] = None, font_file: Optional[str] = None,
page_number: int = 0, page_number: int = 0,
**kwargs,
) -> str: ) -> str:
reader = PdfReader(pdf_path) reader = PdfReader(pdf_path)
writer = PdfWriter() writer = PdfWriter()

View File

@@ -59,11 +59,5 @@ class RagTool(BaseTool):
def _run( def _run(
self, self,
query: str, query: str,
**kwargs: Any, ) -> str:
) -> Any:
self._before_run(query, **kwargs)
return f"Relevant Content:\n{self.adapter.query(query)}" return f"Relevant Content:\n{self.adapter.query(query)}"
def _before_run(self, query, **kwargs):
pass

View File

@@ -41,14 +41,15 @@ class SerplyJobSearchTool(RagTool):
def _run( def _run(
self, self,
**kwargs: Any, query: Optional[str] = None,
) -> Any: search_query: Optional[str] = None,
) -> str:
query_payload = {} query_payload = {}
if "query" in kwargs: if query is not None:
query_payload["q"] = kwargs["query"] query_payload["q"] = query
elif "search_query" in kwargs: elif search_query is not None:
query_payload["q"] = kwargs["search_query"] query_payload["q"] = search_query
# build the url # build the url
url = f"{self.request_url}{urlencode(query_payload)}" url = f"{self.request_url}{urlencode(query_payload)}"

View File

@@ -18,7 +18,9 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
class SerplyWebpageToMarkdownTool(RagTool): class SerplyWebpageToMarkdownTool(RagTool):
name: str = "Webpage to Markdown" name: str = "Webpage to Markdown"
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand" description: str = (
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
)
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
request_url: str = "https://api.serply.io/v1/request" request_url: str = "https://api.serply.io/v1/request"
proxy_location: Optional[str] = "US" proxy_location: Optional[str] = "US"
@@ -39,9 +41,9 @@ class SerplyWebpageToMarkdownTool(RagTool):
def _run( def _run(
self, self,
**kwargs: Any, url: str,
) -> Any: ) -> str:
data = {"url": kwargs["url"], "method": "GET", "response_type": "markdown"} data = {"url": url, "method": "GET", "response_type": "markdown"}
response = requests.request( response = requests.request(
"POST", self.request_url, headers=self.headers, json=data "POST", self.request_url, headers=self.headers, json=data
) )

View File

@@ -1,6 +1,5 @@
from typing import Any, Optional, Type from typing import Optional, Type
from embedchain.models.data_type import DataType
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
@@ -31,30 +30,16 @@ class TXTSearchTool(RagTool):
def __init__(self, txt: Optional[str] = None, **kwargs): def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if txt is not None: if txt is not None:
kwargs["data_type"] = DataType.TEXT_FILE
self.add(txt) self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content." self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema self.args_schema = FixedTXTSearchToolSchema
self._generate_description() self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "txt" in kwargs:
self.add(kwargs["txt"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, txt: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if txt is not None:
self.add(txt)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
class WebsiteSearchTool(RagTool): class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website" name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content." description: str = (
"A tool that can be used to semantic search a query from a specific URL content."
)
args_schema: Type[BaseModel] = WebsiteSearchToolSchema args_schema: Type[BaseModel] = WebsiteSearchToolSchema
def __init__(self, website: Optional[str] = None, **kwargs): def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if website is not None: if website is not None:
kwargs["data_type"] = DataType.WEB_PAGE
self.add(website) self.add(website)
self.description = f"A tool that can be used to semantic search a query from {website} website content." self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema self.args_schema = FixedWebsiteSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, website: str) -> None:
self, super().add(website, data_type=DataType.WEB_PAGE)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "website" in kwargs:
self.add(kwargs["website"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, website: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if website is not None:
self.add(website)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class XMLSearchTool(RagTool):
def __init__(self, xml: Optional[str] = None, **kwargs): def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if xml is not None: if xml is not None:
kwargs["data_type"] = DataType.XML
self.add(xml) self.add(xml)
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content." self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema self.args_schema = FixedXMLSearchToolSchema
self._generate_description() self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "xml" in kwargs:
self.add(kwargs["xml"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, xml: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if xml is not None:
self.add(xml)
return super()._run(query=search_query)

View File

@@ -25,13 +25,14 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
class YoutubeChannelSearchTool(RagTool): class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content" name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." description: str = (
"A tool that can be used to semantic search a query from a Youtube Channels content."
)
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if youtube_channel_handle is not None: if youtube_channel_handle is not None:
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
self.add(youtube_channel_handle) self.add(youtube_channel_handle)
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content." self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema self.args_schema = FixedYoutubeChannelSearchToolSchema
@@ -40,23 +41,16 @@ class YoutubeChannelSearchTool(RagTool):
def add( def add(
self, self,
youtube_channel_handle: str, youtube_channel_handle: str,
**kwargs: Any,
) -> None: ) -> None:
if not youtube_channel_handle.startswith("@"): if not youtube_channel_handle.startswith("@"):
youtube_channel_handle = f"@{youtube_channel_handle}" youtube_channel_handle = f"@{youtube_channel_handle}"
super().add(youtube_channel_handle, **kwargs) super().add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, youtube_channel_handle: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if youtube_channel_handle is not None:
self.add(youtube_channel_handle)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
class YoutubeVideoSearchTool(RagTool): class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content" name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content." description: str = (
"A tool that can be used to semantic search a query from a Youtube Video content."
)
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if youtube_video_url is not None: if youtube_video_url is not None:
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
self.add(youtube_video_url) self.add(youtube_video_url)
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content." self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema self.args_schema = FixedYoutubeVideoSearchToolSchema
self._generate_description() self._generate_description()
def add( def add(self, youtube_video_url: str) -> None:
self, super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"])
def _run( def _run(
self, self,
search_query: str, search_query: str,
**kwargs: Any, youtube_video_url: Optional[str] = None,
) -> Any: ) -> str:
return super()._run(query=search_query, **kwargs) if youtube_video_url is not None:
self.add(youtube_video_url)
return super()._run(query=search_query)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,309 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import ANY, MagicMock
import pytest
from embedchain.models.data_type import DataType
from crewai_tools.tools import (
CodeDocsSearchTool,
CSVSearchTool,
DirectorySearchTool,
DOCXSearchTool,
GithubSearchTool,
JSONSearchTool,
MDXSearchTool,
PDFSearchTool,
TXTSearchTool,
WebsiteSearchTool,
XMLSearchTool,
YoutubeChannelSearchTool,
YoutubeVideoSearchTool,
)
from crewai_tools.tools.rag.rag_tool import Adapter
# Module-level pytest marks: every test in this module carries the `vcr` mark,
# configured with filter_headers=["authorization"].
# NOTE(review): presumably provided by the pytest-recording plugin — confirm.
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
@pytest.fixture
def mock_adapter():
    """Provide a ``MagicMock`` whose attribute surface is limited to ``Adapter``."""
    return MagicMock(spec=Adapter)
def test_directory_search_tool():
    """Index a temporary directory holding one text file and search it."""
    with tempfile.TemporaryDirectory() as workdir:
        sample = Path(workdir) / "test.txt"
        sample.write_text("This is a test file for directory search")

        searcher = DirectorySearchTool(directory=workdir)
        answer = searcher._run(search_query="test file")

        assert "test file" in answer.lower()
def test_pdf_search_tool(mock_adapter):
    """Check the PDF is added with ``DataType.PDF_FILE`` from init or from ``_run``."""
    pdf_path = "test.pdf"
    question = "test content"
    mock_adapter.query.return_value = "this is a test"

    def verify(answer):
        # Same expectations apply to both ways of supplying the PDF.
        assert "this is a test" in answer.lower()
        mock_adapter.add.assert_called_once_with(
            pdf_path, data_type=DataType.PDF_FILE
        )
        mock_adapter.query.assert_called_once_with(question)

    # PDF supplied to the constructor.
    preloaded = PDFSearchTool(pdf=pdf_path, adapter=mock_adapter)
    verify(preloaded._run(query=question))

    # Clear recorded calls so the second scenario is checked in isolation.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # PDF supplied at query time.
    lazy = PDFSearchTool(adapter=mock_adapter)
    verify(lazy._run(pdf=pdf_path, query=question))
def test_txt_search_tool():
    """Add a temporary ``.txt`` file explicitly and search its content."""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as handle:
        handle.write(b"This is a test file for txt search")
        txt_path = handle.name

    try:
        searcher = TXTSearchTool()
        searcher.add(txt_path)
        answer = searcher._run(search_query="test file")
        assert "test file" in answer.lower()
    finally:
        # The file was created with delete=False, so remove it manually.
        os.remove(txt_path)
def test_docx_search_tool(mock_adapter):
    """Check the DOCX is added with ``DataType.DOCX`` from init or from ``_run``."""
    docx_path = "test.docx"
    question = "test content"
    mock_adapter.query.return_value = "this is a test"

    def verify(answer):
        # Same expectations apply to both ways of supplying the document.
        assert "this is a test" in answer.lower()
        mock_adapter.add.assert_called_once_with(docx_path, data_type=DataType.DOCX)
        mock_adapter.query.assert_called_once_with(question)

    # Document supplied to the constructor.
    preloaded = DOCXSearchTool(docx=docx_path, adapter=mock_adapter)
    verify(preloaded._run(search_query=question))

    # Clear recorded calls so the second scenario is checked in isolation.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # Document supplied at query time.
    lazy = DOCXSearchTool(adapter=mock_adapter)
    verify(lazy._run(docx=docx_path, search_query=question))
def test_json_search_tool():
    """Pass a temporary JSON file at query time and search its content."""
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as handle:
        handle.write(b'{"test": "This is a test JSON file"}')
        json_file = handle.name

    try:
        searcher = JSONSearchTool()
        answer = searcher._run(search_query="test JSON", json_path=json_file)
        assert "test json" in answer.lower()
    finally:
        # The file was created with delete=False, so remove it manually.
        os.remove(json_file)
def test_xml_search_tool(mock_adapter):
    """Pass an XML path at query time and check the calls reaching the adapter."""
    xml_path = "test.xml"
    question = "test XML"
    mock_adapter.query.return_value = "this is a test"

    searcher = XMLSearchTool(adapter=mock_adapter)
    answer = searcher._run(search_query=question, xml=xml_path)

    assert "this is a test" in answer.lower()
    # The adapter is expected to receive the path alone, with no data_type keyword.
    mock_adapter.add.assert_called_once_with(xml_path)
    mock_adapter.query.assert_called_once_with(question)
def test_csv_search_tool():
    """Add a temporary CSV file explicitly and search its content."""
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as handle:
        handle.write(b"name,description\ntest,This is a test CSV file")
        csv_path = handle.name

    try:
        searcher = CSVSearchTool()
        searcher.add(csv_path)
        answer = searcher._run(search_query="test CSV")
        assert "test csv" in answer.lower()
    finally:
        # The file was created with delete=False, so remove it manually.
        os.remove(csv_path)
def test_mdx_search_tool():
    """Add a temporary MDX file explicitly and search its content."""
    with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as handle:
        handle.write(b"# Test MDX\nThis is a test MDX file")
        mdx_path = handle.name

    try:
        searcher = MDXSearchTool()
        searcher.add(mdx_path)
        answer = searcher._run(search_query="test MDX")
        assert "test mdx" in answer.lower()
    finally:
        # The file was created with delete=False, so remove it manually.
        os.remove(mdx_path)
def test_website_search_tool(mock_adapter):
    """Check the site is added with ``DataType.WEB_PAGE`` from init or from ``_run``."""
    website = "https://crewai.com"
    search_query = "what is crewai?"
    mock_adapter.query.return_value = "this is a test"

    def verify(answer):
        # Same expectations apply to both ways of supplying the website.
        mock_adapter.query.assert_called_once_with("what is crewai?")
        mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
        assert "this is a test" in answer.lower()

    # Website supplied to the constructor.
    preloaded = WebsiteSearchTool(website=website, adapter=mock_adapter)
    verify(preloaded._run(search_query=search_query))

    # Clear recorded calls so the second scenario is checked in isolation.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # Website supplied at query time.
    lazy = WebsiteSearchTool(adapter=mock_adapter)
    verify(lazy._run(website=website, search_query=search_query))
def test_youtube_video_search_tool(mock_adapter):
    """YoutubeVideoSearchTool adds the video URL given at init or at run time.

    Each scenario expects the canned answer in the result, one ``add`` call
    tagged ``DataType.YOUTUBE_VIDEO`` and one ``query`` call with the search
    text.
    """
    video_url = "https://www.youtube.com/watch?v=sample-video-id"
    question = "what is the video about?"
    mock_adapter.query.return_value = "some video description"

    def verify(answer):
        # Checks shared by both scenarios, in the order they are asserted.
        assert "some video description" in answer
        mock_adapter.add.assert_called_once_with(
            video_url, data_type=DataType.YOUTUBE_VIDEO
        )
        mock_adapter.query.assert_called_once_with(question)

    # Scenario 1: the URL is supplied when the tool is constructed.
    preconfigured_tool = YoutubeVideoSearchTool(
        youtube_video_url=video_url,
        adapter=mock_adapter,
    )
    verify(preconfigured_tool._run(search_query=question))

    # Clear recorded calls so the "called once" checks start from zero again.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # Scenario 2: the URL is supplied only in the ``_run`` call.
    bare_tool = YoutubeVideoSearchTool(adapter=mock_adapter)
    verify(bare_tool._run(youtube_video_url=video_url, search_query=question))
def test_youtube_channel_search_tool(mock_adapter):
    """YoutubeChannelSearchTool adds the handle given at init or at run time.

    Each scenario expects the canned answer in the result, one ``add`` call
    tagged ``DataType.YOUTUBE_CHANNEL`` and one ``query`` call with the
    search text.
    """
    channel_handle = "@crewai"
    question = "what is the channel about?"
    mock_adapter.query.return_value = "channel description"

    def verify(answer):
        # Checks shared by both scenarios, in the order they are asserted.
        assert "channel description" in answer
        mock_adapter.add.assert_called_once_with(
            channel_handle, data_type=DataType.YOUTUBE_CHANNEL
        )
        mock_adapter.query.assert_called_once_with(question)

    # Scenario 1: the handle is supplied when the tool is constructed.
    preconfigured_tool = YoutubeChannelSearchTool(
        youtube_channel_handle=channel_handle, adapter=mock_adapter
    )
    verify(preconfigured_tool._run(search_query=question))

    # Clear recorded calls so the "called once" checks start from zero again.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # Scenario 2: the handle is supplied only in the ``_run`` call.
    bare_tool = YoutubeChannelSearchTool(adapter=mock_adapter)
    verify(
        bare_tool._run(
            youtube_channel_handle=channel_handle, search_query=question
        )
    )
def test_code_docs_search_tool(mock_adapter):
    """CodeDocsSearchTool adds the docs URL given at init or at run time.

    Each scenario expects the canned answer in the result, one ``add`` call
    tagged ``DataType.DOCS_SITE`` and one ``query`` call with the search text.
    """
    docs_url = "https://crewai.com/any-docs-url"
    search_query = "test documentation"
    mock_adapter.query.return_value = "test documentation"

    def verify(answer):
        # Checks shared by both scenarios, in the order they are asserted.
        assert "test documentation" in answer
        mock_adapter.add.assert_called_once_with(
            docs_url, data_type=DataType.DOCS_SITE
        )
        mock_adapter.query.assert_called_once_with(search_query)

    # Scenario 1: the docs URL is supplied when the tool is constructed.
    preconfigured_tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
    verify(preconfigured_tool._run(search_query=search_query))

    # Clear recorded calls so the "called once" checks start from zero again.
    mock_adapter.query.reset_mock()
    mock_adapter.add.reset_mock()

    # Scenario 2: the docs URL is supplied only in the ``_run`` call.
    bare_tool = CodeDocsSearchTool(adapter=mock_adapter)
    verify(bare_tool._run(docs_url=docs_url, search_query=search_query))
def test_github_search_tool(mock_adapter):
    """GithubSearchTool builds the ``repo:... type:...`` source it adds.

    Four scenarios are run against the same injected adapter:

    1. repo and content types given at construction;
    2. repo and content types given in the ``_run`` call;
    3. repo given in ``_run`` with no content types, where the expected
       source lists ``code,repo,pr,issue``;
    4. no repo at all, where nothing is added but the query still runs.
    """
    repo = "crewai/crewai"
    question = "tell me about crewai repo"
    mock_adapter.query.return_value = "repo description"

    def clear_recorded_calls():
        # Forget earlier calls so each scenario's "called once" checks are
        # independent of the previous one.
        mock_adapter.query.reset_mock()
        mock_adapter.add.reset_mock()

    # ensure the provided repo and content types are used after initialization
    tool = GithubSearchTool(
        gh_token="test_token",
        github_repo=repo,
        content_types=["code"],
        adapter=mock_adapter,
    )
    answer = tool._run(search_query=question)
    assert "repo description" in answer
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with(question)

    # ensure content types provided by run call is used
    clear_recorded_calls()
    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    answer = tool._run(
        github_repo=repo,
        content_types=["code", "issue"],
        search_query=question,
    )
    assert "repo description" in answer
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with(question)

    # ensure default content types are used if not provided
    clear_recorded_calls()
    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    answer = tool._run(github_repo=repo, search_query=question)
    assert "repo description" in answer
    mock_adapter.add.assert_called_once_with(
        "repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
    )
    mock_adapter.query.assert_called_once_with(question)

    # ensure nothing is added if no repo is provided
    clear_recorded_calls()
    tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
    tool._run(search_query=question)
    mock_adapter.add.assert_not_called()
    mock_adapter.query.assert_called_once_with(question)