Merge pull request #21 from joaomdmoura/gui/fix-rag-tools

Fix RAG tools
This commit is contained in:
João Moura
2024-04-04 13:45:28 -03:00
committed by GitHub
13 changed files with 92 additions and 1 deletions

View File

@@ -50,3 +50,10 @@ class CodeDocsSearchTool(RagTool):
) -> Any: ) -> Any:
if "docs_url" in kwargs: if "docs_url" in kwargs:
self.add(kwargs["docs_url"]) self.add(kwargs["docs_url"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class CSVSearchTool(RagTool):
) -> Any: ) -> Any:
if "csv" in kwargs: if "csv" in kwargs:
self.add(kwargs["csv"]) self.add(kwargs["csv"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class DirectorySearchTool(RagTool):
) -> Any: ) -> Any:
if "directory" in kwargs: if "directory" in kwargs:
self.add(kwargs["directory"]) self.add(kwargs["directory"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class DOCXSearchTool(RagTool):
) -> Any: ) -> Any:
if "docx" in kwargs: if "docx" in kwargs:
self.add(kwargs["docx"]) self.add(kwargs["docx"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -21,7 +21,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
github_repo: str = Field(..., description="Mandatory github you want to search") github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field( content_types: List[str] = Field(
..., ...,
description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]", description="Mandatory content types you want to be included search, options: [code, repo, pr, issue]",
) )
@@ -56,3 +56,10 @@ class GithubSearchTool(RagTool):
) -> Any: ) -> Any:
if "github_repo" in kwargs: if "github_repo" in kwargs:
self.add(kwargs["github_repo"]) self.add(kwargs["github_repo"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class JSONSearchTool(RagTool):
) -> Any: ) -> Any:
if "json_path" in kwargs: if "json_path" in kwargs:
self.add(kwargs["json_path"]) self.add(kwargs["json_path"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class MDXSearchTool(RagTool):
) -> Any: ) -> Any:
if "mdx" in kwargs: if "mdx" in kwargs:
self.add(kwargs["mdx"]) self.add(kwargs["mdx"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -35,3 +35,10 @@ class PGSearchTool(RagTool):
kwargs["data_type"] = "postgres" kwargs["data_type"] = "postgres"
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri)) kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
super().add(f"SELECT * FROM {table_name};", **kwargs) super().add(f"SELECT * FROM {table_name};", **kwargs)
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class TXTSearchTool(RagTool):
) -> Any: ) -> Any:
if "txt" in kwargs: if "txt" in kwargs:
self.add(kwargs["txt"]) self.add(kwargs["txt"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class WebsiteSearchTool(RagTool):
) -> Any: ) -> Any:
if "website" in kwargs: if "website" in kwargs:
self.add(kwargs["website"]) self.add(kwargs["website"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class XMLSearchTool(RagTool):
) -> Any: ) -> Any:
if "xml" in kwargs: if "xml" in kwargs:
self.add(kwargs["xml"]) self.add(kwargs["xml"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -53,3 +53,10 @@ class YoutubeChannelSearchTool(RagTool):
) -> Any: ) -> Any:
if "youtube_channel_handle" in kwargs: if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"]) self.add(kwargs["youtube_channel_handle"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)

View File

@@ -50,3 +50,10 @@ class YoutubeVideoSearchTool(RagTool):
) -> Any: ) -> Any:
if "youtube_video_url" in kwargs: if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"]) self.add(kwargs["youtube_video_url"])
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
return super()._run(query=search_query)