mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
fix: Remove kwargs from all RagTools (#285)
* fix: remove kwargs from all (except mysql & pg) RagTools
The agent uses the tool description to decide what to propagate when a tool with **kwargs is found, but this often leads to failures during the tool invocation step.
This happens because the final description ends up like this:
```
CrewStructuredTool(name='Knowledge base', description='Tool Name: Knowledge base
Tool Arguments: {'query': {'description': None, 'type': 'str'}, 'kwargs': {'description': None, 'type': 'Any'}}
Tool Description: A knowledge base that can be used to answer questions.')
```
The agent then tries to infer and pass a kwargs parameter, which isn’t supported by the schema at all.
* feat: adding test to search tools
* feat: add db (chromadb folder) to .gitignore
* fix: fix github search integration
A few attributes were missing when calling the .add method: data_type and loader.
Also, update the query search according to the EmbedChain documentation, the query must include the type and repo keys
* fix: rollback YoutubeChannel paramenter
* chore: fix type hinting for CodeDocs search
* fix: ensure proper configuration when call `add`
According to the documentation, some search methods must be defined as either a loader or a data_type. This commit ensures that.
* build: add optional-dependencies for github and xml search
* test: mocking external requests from search_tool tests
* build: add pytest-recording as devDependencie
This commit is contained in:
@@ -31,30 +31,19 @@ class CodeDocsSearchTool(RagTool):
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
kwargs["data_type"] = DataType.DOCS_SITE
|
||||
self.add(docs_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docs_url" in kwargs:
|
||||
self.add(kwargs["docs_url"])
|
||||
def add(self, docs_url: str) -> None:
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
docs_url: Optional[str] = None,
|
||||
) -> str:
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,19 @@ class CSVSearchTool(RagTool):
|
||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if csv is not None:
|
||||
kwargs["data_type"] = DataType.CSV
|
||||
self.add(csv)
|
||||
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
||||
self.args_schema = FixedCSVSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "csv" in kwargs:
|
||||
self.add(kwargs["csv"])
|
||||
def add(self, csv: str) -> None:
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
csv: Optional[str] = None,
|
||||
) -> str:
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -31,30 +31,22 @@ class DirectorySearchTool(RagTool):
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
|
||||
self.add(directory)
|
||||
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
||||
self.args_schema = FixedDirectorySearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "directory" in kwargs:
|
||||
self.add(kwargs["directory"])
|
||||
def add(self, directory: str) -> None:
|
||||
super().add(
|
||||
directory,
|
||||
loader=DirectoryLoader(config=dict(recursive=True)),
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
directory: Optional[str] = None,
|
||||
) -> str:
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -37,36 +37,19 @@ class DOCXSearchTool(RagTool):
|
||||
def __init__(self, docx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docx is not None:
|
||||
kwargs["data_type"] = DataType.DOCX
|
||||
self.add(docx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
||||
self.args_schema = FixedDOCXSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docx" in kwargs:
|
||||
self.add(kwargs["docx"])
|
||||
def add(self, docx: str) -> None:
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
search_query: str,
|
||||
docx: Optional[str] = None,
|
||||
) -> Any:
|
||||
search_query = kwargs.get("search_query")
|
||||
if search_query is None:
|
||||
search_query = kwargs.get("query")
|
||||
|
||||
docx = kwargs.get("docx")
|
||||
if docx is not None:
|
||||
self.add(docx)
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
@@ -27,19 +27,29 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||
|
||||
class GithubSearchTool(RagTool):
|
||||
name: str = "Search a github repo's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
)
|
||||
summarize: bool = False
|
||||
gh_token: str
|
||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||
content_types: List[str]
|
||||
content_types: List[str] = Field(
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: GithubLoader | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if github_repo is not None:
|
||||
kwargs["data_type"] = "github"
|
||||
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
self.add(repo=github_repo)
|
||||
if github_repo and content_types:
|
||||
self.add(repo=github_repo, content_types=content_types)
|
||||
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
self.args_schema = FixedGithubSearchToolSchema
|
||||
self._generate_description()
|
||||
@@ -47,26 +57,25 @@ class GithubSearchTool(RagTool):
|
||||
def add(
|
||||
self,
|
||||
repo: str,
|
||||
content_types: List[str] | None = None,
|
||||
**kwargs: Any,
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
content_types = content_types or self.content_types
|
||||
|
||||
super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "github_repo" in kwargs:
|
||||
self.add(
|
||||
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
|
||||
)
|
||||
super().add(
|
||||
f"repo:{repo} type:{','.join(content_types)}",
|
||||
data_type="github",
|
||||
loader=self._loader,
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
if github_repo:
|
||||
self.add(
|
||||
repo=github_repo,
|
||||
content_types=content_types,
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,16 @@ class JSONSearchTool(RagTool):
|
||||
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if json_path is not None:
|
||||
kwargs["data_type"] = DataType.JSON
|
||||
self.add(json_path)
|
||||
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
|
||||
self.args_schema = FixedJSONSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "json_path" in kwargs:
|
||||
self.add(kwargs["json_path"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
json_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if json_path is not None:
|
||||
self.add(json_path)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,19 @@ class MDXSearchTool(RagTool):
|
||||
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if mdx is not None:
|
||||
kwargs["data_type"] = DataType.MDX
|
||||
self.add(mdx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
||||
self.args_schema = FixedMDXSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "mdx" in kwargs:
|
||||
self.add(kwargs["mdx"])
|
||||
def add(self, mdx: str) -> None:
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
mdx: Optional[str] = None,
|
||||
) -> str:
|
||||
if mdx is not None:
|
||||
self.add(mdx)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -30,39 +30,19 @@ class PDFSearchTool(RagTool):
|
||||
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
kwargs["data_type"] = DataType.PDF_FILE
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from embedchain import App
|
||||
def add(self, pdf: str) -> None:
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
from crewai_tools.adapters.pdf_embedchain_adapter import (
|
||||
PDFEmbedchainAdapter,
|
||||
)
|
||||
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
self.adapter = PDFEmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "pdf" in kwargs:
|
||||
self.add(kwargs["pdf"])
|
||||
pdf: Optional[str] = None,
|
||||
) -> str:
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
return super()._run(query=query)
|
||||
|
||||
@@ -32,7 +32,9 @@ class PDFTextWritingTool(RagTool):
|
||||
"""A tool to add text to specific positions in a PDF, with custom font support."""
|
||||
|
||||
name: str = "PDF Text Writing Tool"
|
||||
description: str = "A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
|
||||
description: str = (
|
||||
"A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
|
||||
)
|
||||
args_schema: Type[BaseModel] = PDFTextWritingToolSchema
|
||||
|
||||
def run(
|
||||
@@ -45,7 +47,6 @@ class PDFTextWritingTool(RagTool):
|
||||
font_name: str = "F1",
|
||||
font_file: Optional[str] = None,
|
||||
page_number: int = 0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
reader = PdfReader(pdf_path)
|
||||
writer = PdfWriter()
|
||||
|
||||
@@ -59,11 +59,5 @@ class RagTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._before_run(query, **kwargs)
|
||||
|
||||
) -> str:
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
|
||||
def _before_run(self, query, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -41,14 +41,15 @@ class SerplyJobSearchTool(RagTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
query: Optional[str] = None,
|
||||
search_query: Optional[str] = None,
|
||||
) -> str:
|
||||
query_payload = {}
|
||||
|
||||
if "query" in kwargs:
|
||||
query_payload["q"] = kwargs["query"]
|
||||
elif "search_query" in kwargs:
|
||||
query_payload["q"] = kwargs["search_query"]
|
||||
if query is not None:
|
||||
query_payload["q"] = query
|
||||
elif search_query is not None:
|
||||
query_payload["q"] = search_query
|
||||
|
||||
# build the url
|
||||
url = f"{self.request_url}{urlencode(query_payload)}"
|
||||
|
||||
@@ -18,7 +18,9 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
|
||||
|
||||
class SerplyWebpageToMarkdownTool(RagTool):
|
||||
name: str = "Webpage to Markdown"
|
||||
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
|
||||
description: str = (
|
||||
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
|
||||
)
|
||||
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
|
||||
request_url: str = "https://api.serply.io/v1/request"
|
||||
proxy_location: Optional[str] = "US"
|
||||
@@ -39,9 +41,9 @@ class SerplyWebpageToMarkdownTool(RagTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
data = {"url": kwargs["url"], "method": "GET", "response_type": "markdown"}
|
||||
url: str,
|
||||
) -> str:
|
||||
data = {"url": url, "method": "GET", "response_type": "markdown"}
|
||||
response = requests.request(
|
||||
"POST", self.request_url, headers=self.headers, json=data
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -31,30 +30,16 @@ class TXTSearchTool(RagTool):
|
||||
def __init__(self, txt: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
kwargs["data_type"] = DataType.TEXT_FILE
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "txt" in kwargs:
|
||||
self.add(kwargs["txt"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
txt: Optional[str] = None,
|
||||
) -> str:
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,36 +25,27 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
|
||||
|
||||
class WebsiteSearchTool(RagTool):
|
||||
name: str = "Search in a specific website"
|
||||
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a specific URL content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
||||
|
||||
def __init__(self, website: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if website is not None:
|
||||
kwargs["data_type"] = DataType.WEB_PAGE
|
||||
self.add(website)
|
||||
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
||||
self.args_schema = FixedWebsiteSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "website" in kwargs:
|
||||
self.add(kwargs["website"])
|
||||
def add(self, website: str) -> None:
|
||||
super().add(website, data_type=DataType.WEB_PAGE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
website: Optional[str] = None,
|
||||
) -> str:
|
||||
if website is not None:
|
||||
self.add(website)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,16 @@ class XMLSearchTool(RagTool):
|
||||
def __init__(self, xml: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if xml is not None:
|
||||
kwargs["data_type"] = DataType.XML
|
||||
self.add(xml)
|
||||
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
||||
self.args_schema = FixedXMLSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "xml" in kwargs:
|
||||
self.add(kwargs["xml"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
xml: Optional[str] = None,
|
||||
) -> str:
|
||||
if xml is not None:
|
||||
self.add(xml)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,13 +25,14 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
|
||||
|
||||
class YoutubeChannelSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Channels content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_channel_handle is not None:
|
||||
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
|
||||
self.add(youtube_channel_handle)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
||||
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
||||
@@ -40,23 +41,16 @@ class YoutubeChannelSearchTool(RagTool):
|
||||
def add(
|
||||
self,
|
||||
youtube_channel_handle: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not youtube_channel_handle.startswith("@"):
|
||||
youtube_channel_handle = f"@{youtube_channel_handle}"
|
||||
super().add(youtube_channel_handle, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_channel_handle" in kwargs:
|
||||
self.add(kwargs["youtube_channel_handle"])
|
||||
super().add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
youtube_channel_handle: Optional[str] = None,
|
||||
) -> str:
|
||||
if youtube_channel_handle is not None:
|
||||
self.add(youtube_channel_handle)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,36 +25,27 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
|
||||
|
||||
class YoutubeVideoSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Video content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_video_url is not None:
|
||||
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
|
||||
self.add(youtube_video_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
||||
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_video_url" in kwargs:
|
||||
self.add(kwargs["youtube_video_url"])
|
||||
def add(self, youtube_video_url: str) -> None:
|
||||
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
youtube_video_url: Optional[str] = None,
|
||||
) -> str:
|
||||
if youtube_video_url is not None:
|
||||
self.add(youtube_video_url)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
309
tests/tools/test_search_tools.py
Normal file
309
tests/tools/test_search_tools.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.tools import (
|
||||
CodeDocsSearchTool,
|
||||
CSVSearchTool,
|
||||
DirectorySearchTool,
|
||||
DOCXSearchTool,
|
||||
GithubSearchTool,
|
||||
JSONSearchTool,
|
||||
MDXSearchTool,
|
||||
PDFSearchTool,
|
||||
TXTSearchTool,
|
||||
WebsiteSearchTool,
|
||||
XMLSearchTool,
|
||||
YoutubeChannelSearchTool,
|
||||
YoutubeVideoSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
mock_adapter = MagicMock(spec=Adapter)
|
||||
return mock_adapter
|
||||
|
||||
|
||||
def test_directory_search_tool():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_file = Path(temp_dir) / "test.txt"
|
||||
test_file.write_text("This is a test file for directory search")
|
||||
|
||||
tool = DirectorySearchTool(directory=temp_dir)
|
||||
result = tool._run(search_query="test file")
|
||||
assert "test file" in result.lower()
|
||||
|
||||
|
||||
def test_pdf_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "this is a test"
|
||||
|
||||
tool = PDFSearchTool(pdf="test.pdf", adapter=mock_adapter)
|
||||
result = tool._run(query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = PDFSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(pdf="test.pdf", query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
|
||||
|
||||
def test_txt_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
|
||||
temp_file.write(b"This is a test file for txt search")
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
tool = TXTSearchTool()
|
||||
tool.add(temp_file_path)
|
||||
result = tool._run(search_query="test file")
|
||||
assert "test file" in result.lower()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_docx_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "this is a test"
|
||||
|
||||
tool = DOCXSearchTool(docx="test.docx", adapter=mock_adapter)
|
||||
result = tool._run(search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = DOCXSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(docx="test.docx", search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
|
||||
|
||||
def test_json_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
|
||||
temp_file.write(b'{"test": "This is a test JSON file"}')
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
tool = JSONSearchTool()
|
||||
result = tool._run(search_query="test JSON", json_path=temp_file_path)
|
||||
assert "test json" in result.lower()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_xml_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "this is a test"
|
||||
|
||||
tool = XMLSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(search_query="test XML", xml="test.xml")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.xml")
|
||||
mock_adapter.query.assert_called_once_with("test XML")
|
||||
|
||||
|
||||
def test_csv_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
|
||||
temp_file.write(b"name,description\ntest,This is a test CSV file")
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
tool = CSVSearchTool()
|
||||
tool.add(temp_file_path)
|
||||
result = tool._run(search_query="test CSV")
|
||||
assert "test csv" in result.lower()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_mdx_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
|
||||
temp_file.write(b"# Test MDX\nThis is a test MDX file")
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
tool = MDXSearchTool()
|
||||
tool.add(temp_file_path)
|
||||
result = tool._run(search_query="test MDX")
|
||||
assert "test mdx" in result.lower()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_website_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "this is a test"
|
||||
|
||||
website = "https://crewai.com"
|
||||
search_query = "what is crewai?"
|
||||
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
|
||||
result = tool._run(search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = WebsiteSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(website=website, search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
|
||||
def test_youtube_video_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "some video description"
|
||||
|
||||
youtube_video_url = "https://www.youtube.com/watch?v=sample-video-id"
|
||||
search_query = "what is the video about?"
|
||||
tool = YoutubeVideoSearchTool(
|
||||
youtube_video_url=youtube_video_url,
|
||||
adapter=mock_adapter,
|
||||
)
|
||||
result = tool._run(search_query=search_query)
|
||||
assert "some video description" in result
|
||||
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = YoutubeVideoSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(youtube_video_url=youtube_video_url, search_query=search_query)
|
||||
assert "some video description" in result
|
||||
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
|
||||
def test_youtube_channel_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "channel description"
|
||||
|
||||
youtube_channel_handle = "@crewai"
|
||||
search_query = "what is the channel about?"
|
||||
tool = YoutubeChannelSearchTool(
|
||||
youtube_channel_handle=youtube_channel_handle, adapter=mock_adapter
|
||||
)
|
||||
result = tool._run(search_query=search_query)
|
||||
assert "channel description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = YoutubeChannelSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(
|
||||
youtube_channel_handle=youtube_channel_handle, search_query=search_query
|
||||
)
|
||||
assert "channel description" in result
|
||||
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
|
||||
def test_code_docs_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "test documentation"
|
||||
|
||||
docs_url = "https://crewai.com/any-docs-url"
|
||||
search_query = "test documentation"
|
||||
tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
|
||||
result = tool._run(search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = CodeDocsSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(docs_url=docs_url, search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
|
||||
|
||||
def test_github_search_tool(mock_adapter):
|
||||
mock_adapter.query.return_value = "repo description"
|
||||
|
||||
# ensure the provided repo and content types are used after initialization
|
||||
tool = GithubSearchTool(
|
||||
gh_token="test_token",
|
||||
github_repo="crewai/crewai",
|
||||
content_types=["code"],
|
||||
adapter=mock_adapter,
|
||||
)
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code", data_type="github", loader=ANY
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
|
||||
# ensure content types provided by run call is used
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
|
||||
result = tool._run(
|
||||
github_repo="crewai/crewai",
|
||||
content_types=["code", "issue"],
|
||||
search_query="tell me about crewai repo",
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
|
||||
# ensure default content types are used if not provided
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
|
||||
result = tool._run(
|
||||
github_repo="crewai/crewai",
|
||||
search_query="tell me about crewai repo",
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
|
||||
# ensure nothing is added if no repo is provided
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
|
||||
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
mock_adapter.add.assert_not_called()
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
Reference in New Issue
Block a user