Fix GithubSearchTool

This commit is contained in:
Gui Vieira
2024-04-05 18:04:45 -03:00
parent 9c98ad455d
commit 776826ec99

View File

@@ -36,18 +36,21 @@ class GithubSearchTool(RagTool):
def __init__(self, github_repo: Optional[str] = None, **kwargs): def __init__(self, github_repo: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if github_repo is not None: if github_repo is not None:
self.add(github_repo) self.add(repo=github_repo)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content." self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
self.args_schema = FixedGithubSearchToolSchema self.args_schema = FixedGithubSearchToolSchema
def add( def add(
self, self,
*args: Any, repo: str,
content_types: List[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
content_types = content_types or self.content_types
kwargs["data_type"] = "github" kwargs["data_type"] = "github"
kwargs["loader"] = GithubLoader(config={"token": self.gh_token}) kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
super().add(*args, **kwargs) super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs)
def _before_run( def _before_run(
self, self,
@@ -55,7 +58,9 @@ class GithubSearchTool(RagTool):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
if "github_repo" in kwargs: if "github_repo" in kwargs:
self.add(kwargs["github_repo"]) self.add(
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
)
def _run( def _run(
self, self,