fix: Remove kwargs from all RagTools (#285)

* fix: remove kwargs from all (except mysql & pg) RagTools

The agent uses the tool description to decide what to propagate when a tool with **kwargs is found, but this often leads to failures during the tool invocation step.

This happens because the final description ends up like this:

```
CrewStructuredTool(name='Knowledge base', description='Tool Name: Knowledge base
Tool Arguments: {'query': {'description': None, 'type': 'str'}, 'kwargs': {'description': None, 'type': 'Any'}}
Tool Description: A knowledge base that can be used to answer questions.')
```

The agent then tries to infer and pass a kwargs parameter, which isn’t supported by the schema at all.

* feat: adding test to search tools

* feat: add db (chromadb folder) to .gitignore

* fix: fix github search integration

A few attributes were missing when calling the .add method: data_type and loader.

Also, update the query search according to the EmbedChain documentation: the query must include the type and repo keys.

* fix: roll back YoutubeChannel parameter

* chore: fix type hinting for CodeDocs search

* fix: ensure proper configuration when calling `add`

According to the documentation, some search methods must be defined as either a loader or a data_type. This commit ensures that.

* build: add optional-dependencies for github and xml search

* test: mocking external requests from search_tool tests

* build: add pytest-recording as a dev dependency
This commit is contained in:
Lucas Gomide
2025-05-05 15:15:50 -03:00
committed by GitHub
parent 93d043bcd4
commit fd4ef4f47a
23 changed files with 2051 additions and 279 deletions

View File

@@ -31,30 +31,19 @@ class CodeDocsSearchTool(RagTool):
def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docs_url is not None:
kwargs["data_type"] = DataType.DOCS_SITE
self.add(docs_url)
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docs_url" in kwargs:
self.add(kwargs["docs_url"])
def add(self, docs_url: str) -> None:
super().add(docs_url, data_type=DataType.DOCS_SITE)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
docs_url: Optional[str] = None,
) -> str:
if docs_url is not None:
self.add(docs_url)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class CSVSearchTool(RagTool):
def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if csv is not None:
kwargs["data_type"] = DataType.CSV
self.add(csv)
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "csv" in kwargs:
self.add(kwargs["csv"])
def add(self, csv: str) -> None:
super().add(csv, data_type=DataType.CSV)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
csv: Optional[str] = None,
) -> str:
if csv is not None:
self.add(csv)
return super()._run(query=search_query)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Optional, Type
from embedchain.loaders.directory_loader import DirectoryLoader
from pydantic import BaseModel, Field
@@ -31,30 +31,22 @@ class DirectorySearchTool(RagTool):
def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
self.add(directory)
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "directory" in kwargs:
self.add(kwargs["directory"])
def add(self, directory: str) -> None:
super().add(
directory,
loader=DirectoryLoader(config=dict(recursive=True)),
)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
directory: Optional[str] = None,
) -> str:
if directory is not None:
self.add(directory)
return super()._run(query=search_query)

View File

@@ -37,36 +37,19 @@ class DOCXSearchTool(RagTool):
def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
kwargs["data_type"] = DataType.DOCX
self.add(docx)
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docx" in kwargs:
self.add(kwargs["docx"])
def add(self, docx: str) -> None:
super().add(docx, data_type=DataType.DOCX)
def _run(
self,
**kwargs: Any,
search_query: str,
docx: Optional[str] = None,
) -> Any:
search_query = kwargs.get("search_query")
if search_query is None:
search_query = kwargs.get("query")
docx = kwargs.get("docx")
if docx is not None:
self.add(docx)
return super()._run(query=search_query, **kwargs)
return super()._run(query=search_query)

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Optional, Type
from typing import List, Optional, Type
from embedchain.loaders.github import GithubLoader
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from ..rag.rag_tool import RagTool
@@ -27,19 +27,29 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
class GithubSearchTool(RagTool):
name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
description: str = (
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
)
summarize: bool = False
gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema
content_types: List[str]
content_types: List[str] = Field(
default_factory=lambda: ["code", "repo", "pr", "issue"],
description="Content types you want to be included search, options: [code, repo, pr, issue]",
)
_loader: GithubLoader | None = PrivateAttr(default=None)
def __init__(self, github_repo: Optional[str] = None, **kwargs):
def __init__(
self,
github_repo: Optional[str] = None,
content_types: Optional[List[str]] = None,
**kwargs,
):
super().__init__(**kwargs)
if github_repo is not None:
kwargs["data_type"] = "github"
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
self._loader = GithubLoader(config={"token": self.gh_token})
self.add(repo=github_repo)
if github_repo and content_types:
self.add(repo=github_repo, content_types=content_types)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
self.args_schema = FixedGithubSearchToolSchema
self._generate_description()
@@ -47,26 +57,25 @@ class GithubSearchTool(RagTool):
def add(
self,
repo: str,
content_types: List[str] | None = None,
**kwargs: Any,
content_types: Optional[List[str]] = None,
) -> None:
content_types = content_types or self.content_types
super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "github_repo" in kwargs:
self.add(
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
)
super().add(
f"repo:{repo} type:{','.join(content_types)}",
data_type="github",
loader=self._loader,
)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
github_repo: Optional[str] = None,
content_types: Optional[List[str]] = None,
) -> str:
if github_repo:
self.add(
repo=github_repo,
content_types=content_types,
)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class JSONSearchTool(RagTool):
def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if json_path is not None:
kwargs["data_type"] = DataType.JSON
self.add(json_path)
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
self.args_schema = FixedJSONSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "json_path" in kwargs:
self.add(kwargs["json_path"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
json_path: Optional[str] = None,
) -> str:
if json_path is not None:
self.add(json_path)
return super()._run(query=search_query)

View File

@@ -31,30 +31,19 @@ class MDXSearchTool(RagTool):
def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if mdx is not None:
kwargs["data_type"] = DataType.MDX
self.add(mdx)
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "mdx" in kwargs:
self.add(kwargs["mdx"])
def add(self, mdx: str) -> None:
super().add(mdx, data_type=DataType.MDX)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
mdx: Optional[str] = None,
) -> str:
if mdx is not None:
self.add(mdx)
return super()._run(query=search_query)

View File

@@ -30,39 +30,19 @@ class PDFSearchTool(RagTool):
def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
kwargs["data_type"] = DataType.PDF_FILE
self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema
self._generate_description()
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
def add(self, pdf: str) -> None:
super().add(pdf, data_type=DataType.PDF_FILE)
from crewai_tools.adapters.pdf_embedchain_adapter import (
PDFEmbedchainAdapter,
)
app = App.from_config(config=self.config) if self.config else App()
self.adapter = PDFEmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "pdf" in kwargs:
self.add(kwargs["pdf"])
pdf: Optional[str] = None,
) -> str:
if pdf is not None:
self.add(pdf)
return super()._run(query=query)

View File

@@ -32,7 +32,9 @@ class PDFTextWritingTool(RagTool):
"""A tool to add text to specific positions in a PDF, with custom font support."""
name: str = "PDF Text Writing Tool"
description: str = "A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
description: str = (
"A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
)
args_schema: Type[BaseModel] = PDFTextWritingToolSchema
def run(
@@ -45,7 +47,6 @@ class PDFTextWritingTool(RagTool):
font_name: str = "F1",
font_file: Optional[str] = None,
page_number: int = 0,
**kwargs,
) -> str:
reader = PdfReader(pdf_path)
writer = PdfWriter()

View File

@@ -59,11 +59,5 @@ class RagTool(BaseTool):
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
self._before_run(query, **kwargs)
) -> str:
return f"Relevant Content:\n{self.adapter.query(query)}"
def _before_run(self, query, **kwargs):
pass

View File

@@ -41,14 +41,15 @@ class SerplyJobSearchTool(RagTool):
def _run(
self,
**kwargs: Any,
) -> Any:
query: Optional[str] = None,
search_query: Optional[str] = None,
) -> str:
query_payload = {}
if "query" in kwargs:
query_payload["q"] = kwargs["query"]
elif "search_query" in kwargs:
query_payload["q"] = kwargs["search_query"]
if query is not None:
query_payload["q"] = query
elif search_query is not None:
query_payload["q"] = search_query
# build the url
url = f"{self.request_url}{urlencode(query_payload)}"

View File

@@ -18,7 +18,9 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
class SerplyWebpageToMarkdownTool(RagTool):
name: str = "Webpage to Markdown"
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
description: str = (
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
)
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
request_url: str = "https://api.serply.io/v1/request"
proxy_location: Optional[str] = "US"
@@ -39,9 +41,9 @@ class SerplyWebpageToMarkdownTool(RagTool):
def _run(
self,
**kwargs: Any,
) -> Any:
data = {"url": kwargs["url"], "method": "GET", "response_type": "markdown"}
url: str,
) -> str:
data = {"url": url, "method": "GET", "response_type": "markdown"}
response = requests.request(
"POST", self.request_url, headers=self.headers, json=data
)

View File

@@ -1,6 +1,5 @@
from typing import Any, Optional, Type
from typing import Optional, Type
from embedchain.models.data_type import DataType
from pydantic import BaseModel, Field
from ..rag.rag_tool import RagTool
@@ -31,30 +30,16 @@ class TXTSearchTool(RagTool):
def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
kwargs["data_type"] = DataType.TEXT_FILE
self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "txt" in kwargs:
self.add(kwargs["txt"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
txt: Optional[str] = None,
) -> str:
if txt is not None:
self.add(txt)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content."
description: str = (
"A tool that can be used to semantic search a query from a specific URL content."
)
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if website is not None:
kwargs["data_type"] = DataType.WEB_PAGE
self.add(website)
self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "website" in kwargs:
self.add(kwargs["website"])
def add(self, website: str) -> None:
super().add(website, data_type=DataType.WEB_PAGE)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
website: Optional[str] = None,
) -> str:
if website is not None:
self.add(website)
return super()._run(query=search_query)

View File

@@ -31,30 +31,16 @@ class XMLSearchTool(RagTool):
def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if xml is not None:
kwargs["data_type"] = DataType.XML
self.add(xml)
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "xml" in kwargs:
self.add(kwargs["xml"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
xml: Optional[str] = None,
) -> str:
if xml is not None:
self.add(xml)
return super()._run(query=search_query)

View File

@@ -25,13 +25,14 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
description: str = (
"A tool that can be used to semantic search a query from a Youtube Channels content."
)
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_channel_handle is not None:
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
self.add(youtube_channel_handle)
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema
@@ -40,23 +41,16 @@ class YoutubeChannelSearchTool(RagTool):
def add(
self,
youtube_channel_handle: str,
**kwargs: Any,
) -> None:
if not youtube_channel_handle.startswith("@"):
youtube_channel_handle = f"@{youtube_channel_handle}"
super().add(youtube_channel_handle, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"])
super().add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
youtube_channel_handle: Optional[str] = None,
) -> str:
if youtube_channel_handle is not None:
self.add(youtube_channel_handle)
return super()._run(query=search_query)

View File

@@ -25,36 +25,27 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
description: str = (
"A tool that can be used to semantic search a query from a Youtube Video content."
)
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_video_url is not None:
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
self.add(youtube_video_url)
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema
self._generate_description()
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"])
def add(self, youtube_video_url: str) -> None:
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query, **kwargs)
youtube_video_url: Optional[str] = None,
) -> str:
if youtube_video_url is not None:
self.add(youtube_video_url)
return super()._run(query=search_query)