mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-13 01:58:30 +00:00
fix: Remove kwargs from all RagTools (#285)
* fix: remove kwargs from all (except mysql & pg) RagTools
The agent uses the tool description to decide what to propagate when a tool with **kwargs is found, but this often leads to failures during the tool invocation step.
This happens because the final description ends up like this:
```
CrewStructuredTool(name='Knowledge base', description='Tool Name: Knowledge base
Tool Arguments: {'query': {'description': None, 'type': 'str'}, 'kwargs': {'description': None, 'type': 'Any'}}
Tool Description: A knowledge base that can be used to answer questions.')
```
The agent then tries to infer and pass a kwargs parameter, which isn’t supported by the schema at all.
* feat: adding test to search tools
* feat: add db (chromadb folder) to .gitignore
* fix: fix github search integration
A few attributes were missing when calling the .add method: data_type and loader.
Also, update the search query according to the EmbedChain documentation: the query must include the type and repo keys.
* fix: roll back YoutubeChannel parameter
* chore: fix type hinting for CodeDocs search
* fix: ensure proper configuration when calling `add`
According to the documentation, some search methods must be defined as either a loader or a data_type. This commit ensures that.
* build: add optional-dependencies for github and xml search
* test: mocking external requests from search_tool tests
* build: add pytest-recording as a dev dependency
This commit is contained in:
@@ -31,30 +31,19 @@ class CodeDocsSearchTool(RagTool):
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
kwargs["data_type"] = DataType.DOCS_SITE
|
||||
self.add(docs_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docs_url" in kwargs:
|
||||
self.add(kwargs["docs_url"])
|
||||
def add(self, docs_url: str) -> None:
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
docs_url: Optional[str] = None,
|
||||
) -> str:
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,19 @@ class CSVSearchTool(RagTool):
|
||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if csv is not None:
|
||||
kwargs["data_type"] = DataType.CSV
|
||||
self.add(csv)
|
||||
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
||||
self.args_schema = FixedCSVSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "csv" in kwargs:
|
||||
self.add(kwargs["csv"])
|
||||
def add(self, csv: str) -> None:
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
csv: Optional[str] = None,
|
||||
) -> str:
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -31,30 +31,22 @@ class DirectorySearchTool(RagTool):
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
|
||||
self.add(directory)
|
||||
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
||||
self.args_schema = FixedDirectorySearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "directory" in kwargs:
|
||||
self.add(kwargs["directory"])
|
||||
def add(self, directory: str) -> None:
|
||||
super().add(
|
||||
directory,
|
||||
loader=DirectoryLoader(config=dict(recursive=True)),
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
directory: Optional[str] = None,
|
||||
) -> str:
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -37,36 +37,19 @@ class DOCXSearchTool(RagTool):
|
||||
def __init__(self, docx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docx is not None:
|
||||
kwargs["data_type"] = DataType.DOCX
|
||||
self.add(docx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
||||
self.args_schema = FixedDOCXSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docx" in kwargs:
|
||||
self.add(kwargs["docx"])
|
||||
def add(self, docx: str) -> None:
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
search_query: str,
|
||||
docx: Optional[str] = None,
|
||||
) -> Any:
|
||||
search_query = kwargs.get("search_query")
|
||||
if search_query is None:
|
||||
search_query = kwargs.get("query")
|
||||
|
||||
docx = kwargs.get("docx")
|
||||
if docx is not None:
|
||||
self.add(docx)
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
@@ -27,19 +27,29 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||
|
||||
class GithubSearchTool(RagTool):
|
||||
name: str = "Search a github repo's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
)
|
||||
summarize: bool = False
|
||||
gh_token: str
|
||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||
content_types: List[str]
|
||||
content_types: List[str] = Field(
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: GithubLoader | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if github_repo is not None:
|
||||
kwargs["data_type"] = "github"
|
||||
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
self.add(repo=github_repo)
|
||||
if github_repo and content_types:
|
||||
self.add(repo=github_repo, content_types=content_types)
|
||||
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
|
||||
self.args_schema = FixedGithubSearchToolSchema
|
||||
self._generate_description()
|
||||
@@ -47,26 +57,25 @@ class GithubSearchTool(RagTool):
|
||||
def add(
|
||||
self,
|
||||
repo: str,
|
||||
content_types: List[str] | None = None,
|
||||
**kwargs: Any,
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
content_types = content_types or self.content_types
|
||||
|
||||
super().add(f"repo:{repo} type:{','.join(content_types)}", **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "github_repo" in kwargs:
|
||||
self.add(
|
||||
repo=kwargs["github_repo"], content_types=kwargs.get("content_types")
|
||||
)
|
||||
super().add(
|
||||
f"repo:{repo} type:{','.join(content_types)}",
|
||||
data_type="github",
|
||||
loader=self._loader,
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
if github_repo:
|
||||
self.add(
|
||||
repo=github_repo,
|
||||
content_types=content_types,
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,16 @@ class JSONSearchTool(RagTool):
|
||||
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if json_path is not None:
|
||||
kwargs["data_type"] = DataType.JSON
|
||||
self.add(json_path)
|
||||
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
|
||||
self.args_schema = FixedJSONSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "json_path" in kwargs:
|
||||
self.add(kwargs["json_path"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
json_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if json_path is not None:
|
||||
self.add(json_path)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,19 @@ class MDXSearchTool(RagTool):
|
||||
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if mdx is not None:
|
||||
kwargs["data_type"] = DataType.MDX
|
||||
self.add(mdx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
||||
self.args_schema = FixedMDXSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "mdx" in kwargs:
|
||||
self.add(kwargs["mdx"])
|
||||
def add(self, mdx: str) -> None:
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
mdx: Optional[str] = None,
|
||||
) -> str:
|
||||
if mdx is not None:
|
||||
self.add(mdx)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -30,39 +30,19 @@ class PDFSearchTool(RagTool):
|
||||
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
kwargs["data_type"] = DataType.PDF_FILE
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from embedchain import App
|
||||
def add(self, pdf: str) -> None:
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
from crewai_tools.adapters.pdf_embedchain_adapter import (
|
||||
PDFEmbedchainAdapter,
|
||||
)
|
||||
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
self.adapter = PDFEmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "pdf" in kwargs:
|
||||
self.add(kwargs["pdf"])
|
||||
pdf: Optional[str] = None,
|
||||
) -> str:
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
return super()._run(query=query)
|
||||
|
||||
@@ -32,7 +32,9 @@ class PDFTextWritingTool(RagTool):
|
||||
"""A tool to add text to specific positions in a PDF, with custom font support."""
|
||||
|
||||
name: str = "PDF Text Writing Tool"
|
||||
description: str = "A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
|
||||
description: str = (
|
||||
"A tool that can write text to a specific position in a PDF document, with optional custom font embedding."
|
||||
)
|
||||
args_schema: Type[BaseModel] = PDFTextWritingToolSchema
|
||||
|
||||
def run(
|
||||
@@ -45,7 +47,6 @@ class PDFTextWritingTool(RagTool):
|
||||
font_name: str = "F1",
|
||||
font_file: Optional[str] = None,
|
||||
page_number: int = 0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
reader = PdfReader(pdf_path)
|
||||
writer = PdfWriter()
|
||||
|
||||
@@ -59,11 +59,5 @@ class RagTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._before_run(query, **kwargs)
|
||||
|
||||
) -> str:
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
|
||||
def _before_run(self, query, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -41,14 +41,15 @@ class SerplyJobSearchTool(RagTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
query: Optional[str] = None,
|
||||
search_query: Optional[str] = None,
|
||||
) -> str:
|
||||
query_payload = {}
|
||||
|
||||
if "query" in kwargs:
|
||||
query_payload["q"] = kwargs["query"]
|
||||
elif "search_query" in kwargs:
|
||||
query_payload["q"] = kwargs["search_query"]
|
||||
if query is not None:
|
||||
query_payload["q"] = query
|
||||
elif search_query is not None:
|
||||
query_payload["q"] = search_query
|
||||
|
||||
# build the url
|
||||
url = f"{self.request_url}{urlencode(query_payload)}"
|
||||
|
||||
@@ -18,7 +18,9 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
|
||||
|
||||
class SerplyWebpageToMarkdownTool(RagTool):
|
||||
name: str = "Webpage to Markdown"
|
||||
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
|
||||
description: str = (
|
||||
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
|
||||
)
|
||||
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
|
||||
request_url: str = "https://api.serply.io/v1/request"
|
||||
proxy_location: Optional[str] = "US"
|
||||
@@ -39,9 +41,9 @@ class SerplyWebpageToMarkdownTool(RagTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
data = {"url": kwargs["url"], "method": "GET", "response_type": "markdown"}
|
||||
url: str,
|
||||
) -> str:
|
||||
data = {"url": url, "method": "GET", "response_type": "markdown"}
|
||||
response = requests.request(
|
||||
"POST", self.request_url, headers=self.headers, json=data
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
@@ -31,30 +30,16 @@ class TXTSearchTool(RagTool):
|
||||
def __init__(self, txt: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
kwargs["data_type"] = DataType.TEXT_FILE
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "txt" in kwargs:
|
||||
self.add(kwargs["txt"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
txt: Optional[str] = None,
|
||||
) -> str:
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,36 +25,27 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
|
||||
|
||||
class WebsiteSearchTool(RagTool):
|
||||
name: str = "Search in a specific website"
|
||||
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a specific URL content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
||||
|
||||
def __init__(self, website: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if website is not None:
|
||||
kwargs["data_type"] = DataType.WEB_PAGE
|
||||
self.add(website)
|
||||
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
||||
self.args_schema = FixedWebsiteSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "website" in kwargs:
|
||||
self.add(kwargs["website"])
|
||||
def add(self, website: str) -> None:
|
||||
super().add(website, data_type=DataType.WEB_PAGE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
website: Optional[str] = None,
|
||||
) -> str:
|
||||
if website is not None:
|
||||
self.add(website)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -31,30 +31,16 @@ class XMLSearchTool(RagTool):
|
||||
def __init__(self, xml: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if xml is not None:
|
||||
kwargs["data_type"] = DataType.XML
|
||||
self.add(xml)
|
||||
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
||||
self.args_schema = FixedXMLSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "xml" in kwargs:
|
||||
self.add(kwargs["xml"])
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
xml: Optional[str] = None,
|
||||
) -> str:
|
||||
if xml is not None:
|
||||
self.add(xml)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,13 +25,14 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
|
||||
|
||||
class YoutubeChannelSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Channels content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_channel_handle is not None:
|
||||
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
|
||||
self.add(youtube_channel_handle)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
||||
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
||||
@@ -40,23 +41,16 @@ class YoutubeChannelSearchTool(RagTool):
|
||||
def add(
|
||||
self,
|
||||
youtube_channel_handle: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not youtube_channel_handle.startswith("@"):
|
||||
youtube_channel_handle = f"@{youtube_channel_handle}"
|
||||
super().add(youtube_channel_handle, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_channel_handle" in kwargs:
|
||||
self.add(kwargs["youtube_channel_handle"])
|
||||
super().add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
youtube_channel_handle: Optional[str] = None,
|
||||
) -> str:
|
||||
if youtube_channel_handle is not None:
|
||||
self.add(youtube_channel_handle)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
@@ -25,36 +25,27 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
|
||||
|
||||
class YoutubeVideoSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Video content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_video_url is not None:
|
||||
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
|
||||
self.add(youtube_video_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
||||
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_video_url" in kwargs:
|
||||
self.add(kwargs["youtube_video_url"])
|
||||
def add(self, youtube_video_url: str) -> None:
|
||||
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
youtube_video_url: Optional[str] = None,
|
||||
) -> str:
|
||||
if youtube_video_url is not None:
|
||||
self.add(youtube_video_url)
|
||||
return super()._run(query=search_query)
|
||||
|
||||
Reference in New Issue
Block a user