From 1c8d010601512cd2410515cdf78f95fceb2c0c33 Mon Sep 17 00:00:00 2001 From: Gui Vieira Date: Tue, 19 Mar 2024 18:47:13 -0300 Subject: [PATCH] Custom model config for RAG tools --- .../adapters/embedchain_adapter.py | 17 +++- src/crewai_tools/adapters/lancedb_adapter.py | 9 ++- src/crewai_tools/tools/base_tool.py | 69 ++++++++++------ .../code_docs_search_tool.py | 67 +++++++++------- .../tools/csv_search_tool/csv_search_tool.py | 67 +++++++++------- .../directory_search_tool.py | 68 +++++++++------- .../docx_search_tool/docx_search_tool.py | 67 +++++++++------- .../github_search_tool/github_search_tool.py | 78 +++++++++++-------- .../json_search_tool/json_search_tool.py | 67 +++++++++------- .../tools/mdx_seach_tool/mdx_search_tool.py | 67 +++++++++------- .../tools/pdf_search_tool/pdf_search_tool.py | 66 +++++++++------- .../tools/pg_seach_tool/pg_search_tool.py | 62 +++++++-------- src/crewai_tools/tools/rag/README.md | 5 +- src/crewai_tools/tools/rag/rag_tool.py | 63 +++++++++++---- .../tools/txt_search_tool/txt_search_tool.py | 68 +++++++++------- .../website_search/website_search_tool.py | 67 +++++++++------- .../tools/xml_search_tool/xml_search_tool.py | 67 +++++++++------- .../youtube_channel_search_tool.py | 72 ++++++++++------- .../youtube_video_search_tool.py | 67 +++++++++------- tests/tools/rag/rag_tool_test.py | 43 ++++++++++ 20 files changed, 704 insertions(+), 452 deletions(-) create mode 100644 tests/tools/rag/rag_tool_test.py diff --git a/src/crewai_tools/adapters/embedchain_adapter.py b/src/crewai_tools/adapters/embedchain_adapter.py index 16491fb25..446aab96c 100644 --- a/src/crewai_tools/adapters/embedchain_adapter.py +++ b/src/crewai_tools/adapters/embedchain_adapter.py @@ -1,12 +1,25 @@ from typing import Any + +from embedchain import App + from crewai_tools.tools.rag.rag_tool import Adapter + class EmbedchainAdapter(Adapter): - embedchain_app: Any + embedchain_app: App summarize: bool = False def query(self, question: str) -> str: - result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize)) + result, sources = self.embedchain_app.query( + question, citations=True, dry_run=(not self.summarize) + ) if self.summarize: return result return "\n\n".join([source[0] for source in sources]) + + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.embedchain_app.add(*args, **kwargs) diff --git a/src/crewai_tools/adapters/lancedb_adapter.py b/src/crewai_tools/adapters/lancedb_adapter.py index c612d475c..c91423048 100644 --- a/src/crewai_tools/adapters/lancedb_adapter.py +++ b/src/crewai_tools/adapters/lancedb_adapter.py @@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter): self._db = lancedb_connect(self.uri) self._table = self._db.open_table(self.table_name) - return super().model_post_init(__context) + super().model_post_init(__context) def query(self, question: str) -> str: query = self.embedding_function([question])[0] @@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter): ) values = [result[self.text_column_name] for result in results] return "\n".join(values) + + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self._table.add(*args, **kwargs) diff --git a/src/crewai_tools/tools/base_tool.py b/src/crewai_tools/tools/base_tool.py index 545529bdd..961688629 100644 --- a/src/crewai_tools/tools/base_tool.py +++ b/src/crewai_tools/tools/base_tool.py @@ -1,28 +1,47 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Optional, Type -from pydantic import BaseModel, model_validator +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, ConfigDict, Field, validator from pydantic.v1 import BaseModel as V1BaseModel -from langchain_core.tools import StructuredTool class BaseTool(BaseModel, ABC): + class _ArgsSchemaPlaceholder(V1BaseModel): + pass + + model_config = ConfigDict() + name: str """The unique name of the tool that clearly communicates its purpose.""" description: str """Used to tell the model how/when/why to use the tool.""" - args_schema: Optional[Type[V1BaseModel]] = None + args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder) """The schema for the arguments that the tool accepts.""" description_updated: bool = False """Flag to check if the description has been updated.""" cache_function: Optional[Callable] = lambda: True """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.""" - @model_validator(mode="after") - def _check_args_schema(self): - self._set_args_schema() + @validator("args_schema", always=True, pre=True) + def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]: + if not isinstance(v, cls._ArgsSchemaPlaceholder): + return v + + return type( + f"{cls.__name__}Schema", + (V1BaseModel,), + { + "__annotations__": { + k: v for k, v in cls._run.__annotations__.items() if k != "return" + }, + }, + ) + + def model_post_init(self, __context: Any) -> None: self._generate_description() - return self + + super().model_post_init(__context) def run( self, @@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC): (V1BaseModel,), { "__annotations__": { - k: v for k, v in self._run.__annotations__.items() if k != 'return' + k: v + for k, v in self._run.__annotations__.items() + if k != "return" }, }, ) + def _generate_description(self): args = [] - for arg, attribute in self.args_schema.schema()['properties'].items(): - args.append(f"{arg}: '{attribute['type']}'") + for arg, attribute in self.args_schema.schema()["properties"].items(): + if "type" in attribute: + args.append(f"{arg}: '{attribute['type']}'") - description = self.description.replace('\n', ' ') + description = self.description.replace("\n", " ") self.description = f"{self.name}({', '.join(args)}) - {description}" @@ -93,19 +116,19 @@ def tool(*args): def _make_tool(f: Callable) -> BaseTool: if f.__doc__ is None: raise ValueError("Function must have a docstring") + if f.__annotations__ is None: + raise ValueError("Function must have type annotations") - args_schema = None - if f.__annotations__: - class_name = "".join(tool_name.split()).title() - args_schema = type( - class_name, - (V1BaseModel,), - { - "__annotations__": { - k: v for k, v in f.__annotations__.items() if k != 'return' - }, + class_name = "".join(tool_name.split()).title() + args_schema = type( + class_name, + (V1BaseModel,), + { + "__annotations__": { + k: v for k, v in f.__annotations__.items() if k != "return" }, - ) + }, + ) return Tool( name=tool_name, @@ -120,4 +143,4 @@ def tool(*args): return _make_with_name(args[0].__name__)(args[0]) if len(args) == 1 and isinstance(args[0], str): return _make_with_name(args[0]) - raise ValueError("Invalid arguments") \ No newline at end of file + raise ValueError("Invalid arguments") diff --git a/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py b/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py index 54ba69d01..195cc8a05 100644 --- a/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py +++ b/src/crewai_tools/tools/code_docs_search_tool/code_docs_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedCodeDocsSearchToolSchema(BaseModel): - """Input for CodeDocsSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the Code Docs content") + """Input for CodeDocsSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the Code Docs content", + ) + class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema): - """Input for CodeDocsSearchTool.""" - docs_url: str = Field(..., description="Mandatory docs_url path you want to search") + """Input for CodeDocsSearchTool.""" + + docs_url: str = Field(..., description="Mandatory docs_url path you want to search") + class CodeDocsSearchTool(RagTool): - name: str = "Search a Code Docs content" - description: str = "A tool that can be used to semantic search a query from a Code Docs content." - summarize: bool = False - args_schema: Type[BaseModel] = CodeDocsSearchToolSchema - docs_url: Optional[str] = None + name: str = "Search a Code Docs content" + description: str = ( + "A tool that can be used to semantic search a query from a Code Docs content." + ) + args_schema: Type[BaseModel] = CodeDocsSearchToolSchema - def __init__(self, docs_url: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if docs_url is not None: - self.docs_url = docs_url - self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content." - self.args_schema = FixedCodeDocsSearchToolSchema - self._generate_description() + def __init__(self, docs_url: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if docs_url is not None: + self.add(docs_url) + self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content." + self.args_schema = FixedCodeDocsSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - docs_url = kwargs.get('docs_url', self.docs_url) - self.app = App() - self.app.add(docs_url, data_type=DataType.DOCS_SITE) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.DOCS_SITE + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "docs_url" in kwargs: + self.add(kwargs["docs_url"]) diff --git a/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py b/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py index cd99ebfd2..6b8e79f88 100644 --- a/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py +++ b/src/crewai_tools/tools/csv_search_tool/csv_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedCSVSearchToolSchema(BaseModel): - """Input for CSVSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the CSV's content") + """Input for CSVSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the CSV's content", + ) + class CSVSearchToolSchema(FixedCSVSearchToolSchema): - """Input for CSVSearchTool.""" - csv: str = Field(..., description="Mandatory csv path you want to search") + """Input for CSVSearchTool.""" + + csv: str = Field(..., description="Mandatory csv path you want to search") + class CSVSearchTool(RagTool): - name: str = "Search a CSV's content" - description: str = "A tool that can be used to semantic search a query from a CSV's content." - summarize: bool = False - args_schema: Type[BaseModel] = CSVSearchToolSchema - csv: Optional[str] = None + name: str = "Search a CSV's content" + description: str = ( + "A tool that can be used to semantic search a query from a CSV's content." + ) + args_schema: Type[BaseModel] = CSVSearchToolSchema - def __init__(self, csv: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if csv is not None: - self.csv = csv - self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content." - self.args_schema = FixedCSVSearchToolSchema - self._generate_description() + def __init__(self, csv: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if csv is not None: + self.add(csv) + self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content." + self.args_schema = FixedCSVSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - csv = kwargs.get('csv', self.csv) - self.app = App() - self.app.add(csv, data_type=DataType.CSV) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.CSV + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "csv" in kwargs: + self.add(kwargs["csv"]) diff --git a/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py b/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py index 2cd888a8b..7f20f5979 100644 --- a/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py +++ b/src/crewai_tools/tools/directory_search_tool/directory_search_tool.py @@ -1,42 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.loaders.directory_loader import DirectoryLoader +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedDirectorySearchToolSchema(BaseModel): - """Input for DirectorySearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the directory's content") + """Input for DirectorySearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the directory's content", + ) + class DirectorySearchToolSchema(FixedDirectorySearchToolSchema): - """Input for DirectorySearchTool.""" - directory: str = Field(..., description="Mandatory directory you want to search") + """Input for DirectorySearchTool.""" + + directory: str = Field(..., description="Mandatory directory you want to search") + class DirectorySearchTool(RagTool): - name: str = "Search a directory's content" - description: str = "A tool that can be used to semantic search a query from a directory's content." - summarize: bool = False - args_schema: Type[BaseModel] = DirectorySearchToolSchema - directory: Optional[str] = None + name: str = "Search a directory's content" + description: str = ( + "A tool that can be used to semantic search a query from a directory's content." + ) + args_schema: Type[BaseModel] = DirectorySearchToolSchema - def __init__(self, directory: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if directory is not None: - self.directory = directory - self.description = f"A tool that can be used to semantic search a query the {directory} directory's content." - self.args_schema = FixedDirectorySearchToolSchema - self._generate_description() + def __init__(self, directory: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if directory is not None: + self.add(directory) + self.description = f"A tool that can be used to semantic search a query the {directory} directory's content." + self.args_schema = FixedDirectorySearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - directory = kwargs.get('directory', self.directory) - loader = DirectoryLoader(config=dict(recursive=True)) - self.app = App() - self.app.add(directory, loader=loader) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["loader"] = DirectoryLoader(config=dict(recursive=True)) + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "directory" in kwargs: + self.add(kwargs["directory"]) diff --git a/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py b/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py index 135837a6b..5c64f9824 100644 --- a/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py +++ b/src/crewai_tools/tools/docx_search_tool/docx_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedDOCXSearchToolSchema(BaseModel): - """Input for DOCXSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the DOCX's content") + """Input for DOCXSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the DOCX's content", + ) + class DOCXSearchToolSchema(FixedDOCXSearchToolSchema): - """Input for DOCXSearchTool.""" - docx: str = Field(..., description="Mandatory docx path you want to search") + """Input for DOCXSearchTool.""" + + docx: str = Field(..., description="Mandatory docx path you want to search") + class DOCXSearchTool(RagTool): - name: str = "Search a DOCX's content" - description: str = "A tool that can be used to semantic search a query from a DOCX's content." - summarize: bool = False - args_schema: Type[BaseModel] = DOCXSearchToolSchema - docx: Optional[str] = None + name: str = "Search a DOCX's content" + description: str = ( + "A tool that can be used to semantic search a query from a DOCX's content." + ) + args_schema: Type[BaseModel] = DOCXSearchToolSchema - def __init__(self, docx: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if docx is not None: - self.docx = docx - self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content." - self.args_schema = FixedDOCXSearchToolSchema - self._generate_description() + def __init__(self, docx: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if docx is not None: + self.add(docx) + self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content." + self.args_schema = FixedDOCXSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - docx = kwargs.get('docx', self.docx) - self.app = App() - self.app.add(docx, data_type=DataType.DOCX) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.DOCX + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "docx" in kwargs: + self.add(kwargs["docx"]) diff --git a/src/crewai_tools/tools/github_search_tool/github_search_tool.py b/src/crewai_tools/tools/github_search_tool/github_search_tool.py index cb2815aad..4a84b166c 100644 --- a/src/crewai_tools/tools/github_search_tool/github_search_tool.py +++ b/src/crewai_tools/tools/github_search_tool/github_search_tool.py @@ -1,46 +1,58 @@ -from typing import Optional, Type, List, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, List, Optional, Type -from embedchain import App from embedchain.loaders.github import GithubLoader +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedGithubSearchToolSchema(BaseModel): - """Input for GithubSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the github repo's content") + """Input for GithubSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the github repo's content", + ) + class GithubSearchToolSchema(FixedGithubSearchToolSchema): - """Input for GithubSearchTool.""" - github_repo: str = Field(..., description="Mandatory github you want to search") - content_types: List[str] = Field(..., description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]") + """Input for GithubSearchTool.""" + + github_repo: str = Field(..., description="Mandatory github you want to search") + content_types: List[str] = Field( + ..., + description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]", + ) + class GithubSearchTool(RagTool): - name: str = "Search a github repo's content" - description: str = "A tool that can be used to semantic search a query from a github repo's content." - summarize: bool = False - gh_token: str = None - args_schema: Type[BaseModel] = GithubSearchToolSchema - github_repo: Optional[str] = None - content_types: List[str] + name: str = "Search a github repo's content" + description: str = "A tool that can be used to semantic search a query from a github repo's content." + summarize: bool = False + gh_token: str + args_schema: Type[BaseModel] = GithubSearchToolSchema + content_types: List[str] - def __init__(self, github_repo: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if github_repo is not None: - self.github_repo = github_repo - self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content." - self.args_schema = FixedGithubSearchToolSchema - self._generate_description() + def __init__(self, github_repo: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if github_repo is not None: + self.add(github_repo) + self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content." + self.args_schema = FixedGithubSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - github_repo = kwargs.get('github_repo', self.github_repo) - loader = GithubLoader(config={"token": self.gh_token}) - app = App() - app.add(f"repo:{github_repo} type:{','.join(self.content_types)}", data_type="github", loader=loader) - self.app = app - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = "github" + kwargs["loader"] = GithubLoader(config={"token": self.gh_token}) + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "github_repo" in kwargs: + self.add(kwargs["github_repo"]) diff --git a/src/crewai_tools/tools/json_search_tool/json_search_tool.py b/src/crewai_tools/tools/json_search_tool/json_search_tool.py index 578f06bc9..308dca726 100644 --- a/src/crewai_tools/tools/json_search_tool/json_search_tool.py +++ b/src/crewai_tools/tools/json_search_tool/json_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedJSONSearchToolSchema(BaseModel): - """Input for JSONSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the JSON's content") + """Input for JSONSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the JSON's content", + ) + class JSONSearchToolSchema(FixedJSONSearchToolSchema): - """Input for JSONSearchTool.""" - json_path: str = Field(..., description="Mandatory json path you want to search") + """Input for JSONSearchTool.""" + + json_path: str = Field(..., description="Mandatory json path you want to search") + class JSONSearchTool(RagTool): - name: str = "Search a JSON's content" - description: str = "A tool that can be used to semantic search a query from a JSON's content." - summarize: bool = False - args_schema: Type[BaseModel] = JSONSearchToolSchema - json_path: Optional[str] = None + name: str = "Search a JSON's content" + description: str = ( + "A tool that can be used to semantic search a query from a JSON's content." + ) + args_schema: Type[BaseModel] = JSONSearchToolSchema - def __init__(self, json_path: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if json_path is not None: - self.json_path = json_path - self.description = f"A tool that can be used to semantic search a query the {json} JSON's content." - self.args_schema = FixedJSONSearchToolSchema - self._generate_description() + def __init__(self, json_path: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if json_path is not None: + self.add(json_path) + self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content." + self.args_schema = FixedJSONSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - json_path = kwargs.get('json_path', self.json_path) - self.app = App() - self.app.add(json_path, data_type=DataType.JSON) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.JSON + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "json_path" in kwargs: + self.add(kwargs["json_path"]) diff --git a/src/crewai_tools/tools/mdx_seach_tool/mdx_search_tool.py b/src/crewai_tools/tools/mdx_seach_tool/mdx_search_tool.py index e34c0fa08..33a58e142 100644 --- a/src/crewai_tools/tools/mdx_seach_tool/mdx_search_tool.py +++ b/src/crewai_tools/tools/mdx_seach_tool/mdx_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedMDXSearchToolSchema(BaseModel): - """Input for MDXSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the MDX's content") + """Input for MDXSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the MDX's content", + ) + class MDXSearchToolSchema(FixedMDXSearchToolSchema): - """Input for MDXSearchTool.""" - mdx: str = Field(..., description="Mandatory mdx path you want to search") + """Input for MDXSearchTool.""" + + mdx: str = Field(..., description="Mandatory mdx path you want to search") + class MDXSearchTool(RagTool): - name: str = "Search a MDX's content" - description: str = "A tool that can be used to semantic search a query from a MDX's content." - summarize: bool = False - args_schema: Type[BaseModel] = MDXSearchToolSchema - mdx: Optional[str] = None + name: str = "Search a MDX's content" + description: str = ( + "A tool that can be used to semantic search a query from a MDX's content." + ) + args_schema: Type[BaseModel] = MDXSearchToolSchema - def __init__(self, mdx: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if mdx is not None: - self.mdx = mdx - self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content." - self.args_schema = FixedMDXSearchToolSchema - self._generate_description() + def __init__(self, mdx: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if mdx is not None: + self.add(mdx) + self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content." + self.args_schema = FixedMDXSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - mdx = kwargs.get('mdx', self.mdx) - self.app = App() - self.app.add(mdx, data_type=DataType.MDX) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.MDX + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "mdx" in kwargs: + self.add(kwargs["mdx"]) diff --git a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py index e75cb8610..47e425a45 100644 --- a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py +++ b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py @@ -1,41 +1,51 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedPDFSearchToolSchema(BaseModel): - """Input for PDFSearchTool.""" - query: str = Field(..., description="Mandatory query you want to use to search the PDF's content") + """Input for PDFSearchTool.""" + + query: str = Field( + ..., description="Mandatory query you want to use to search the PDF's content" + ) + class PDFSearchToolSchema(FixedPDFSearchToolSchema): - """Input for PDFSearchTool.""" - pdf: str = Field(..., description="Mandatory pdf path you want to search") + """Input for PDFSearchTool.""" + + pdf: str = Field(..., description="Mandatory pdf path you want to search") + class PDFSearchTool(RagTool): - name: str = "Search a PDF's content" - description: str = "A tool that can be used to semantic search a query from a PDF's content." - summarize: bool = False - args_schema: Type[BaseModel] = PDFSearchToolSchema - pdf: Optional[str] = None + name: str = "Search a PDF's content" + description: str = ( + "A tool that can be used to semantic search a query from a PDF's content." + ) + args_schema: Type[BaseModel] = PDFSearchToolSchema - def __init__(self, pdf: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if pdf is not None: - self.pdf = pdf - self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content." - self.args_schema = FixedPDFSearchToolSchema - self._generate_description() + def __init__(self, pdf: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if pdf is not None: + self.add(pdf) + self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content." + self.args_schema = FixedPDFSearchToolSchema - def _run( - self, - query: str, - **kwargs: Any, - ) -> Any: - pdf = kwargs.get('pdf', self.pdf) - self.app = App() - self.app.add(pdf, data_type=DataType.PDF_FILE) - return super()._run(query=query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.PDF_FILE + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "pdf" in kwargs: + self.add(kwargs["pdf"]) diff --git a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py index 8b9707185..f22cac123 100644 --- a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py +++ b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py @@ -1,45 +1,37 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Type -from embedchain import App from embedchain.loaders.postgres import PostgresLoader +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool + class PGSearchToolSchema(BaseModel): - """Input for PGSearchTool.""" - search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content") + """Input for PGSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory semantic search query you want to use to search the database's content", + ) + class PGSearchTool(RagTool): - name: str = "Search a database's table content" - description: str = "A tool that can be used to semantic search a query from a database table's content." - summarize: bool = False - args_schema: Type[BaseModel] = PGSearchToolSchema - db_uri: str = Field(..., description="Mandatory database URI") - table_name: str = Field(..., description="Mandatory table name") - search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content") + name: str = "Search a database's table content" + description: str = "A tool that can be used to semantic search a query from a database table's content." + args_schema: Type[BaseModel] = PGSearchToolSchema + db_uri: str = Field(..., description="Mandatory database URI") - def __init__(self, table_name: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if table_name is not None: - self.table_name = table_name - self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content." - self._generate_description() - else: - raise('To use PGSearchTool, you must provide a `table_name` argument') + def __init__(self, table_name: str, **kwargs): + super().__init__(**kwargs) + self.add(table_name) + self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content." + self._generate_description() - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - - config = { "url": self.db_uri } - postgres_loader = PostgresLoader(config=config) - app = App() - app.add( - f"SELECT * FROM {self.table_name};", - data_type='postgres', - loader=postgres_loader - ) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + table_name: str, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = "postgres" + kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri)) + super().add(f"SELECT * FROM {table_name};", **kwargs) diff --git a/src/crewai_tools/tools/rag/README.md b/src/crewai_tools/tools/rag/README.md index c65daca16..b432a1a69 100644 --- a/src/crewai_tools/tools/rag/README.md +++ b/src/crewai_tools/tools/rag/README.md @@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory') # Example: Loading from a web page rag_tool = RagTool().from_web_page('https://example.com') - -# Example: Loading from an Embedchain configuration -rag_tool = RagTool().from_embedchain('path/to/your/config.json') ``` ## **Contribution** @@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To RagTool is open-source and available under the MIT license. -Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better. \ No newline at end of file +Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better. diff --git a/src/crewai_tools/tools/rag/rag_tool.py b/src/crewai_tools/tools/rag/rag_tool.py index 3901129ff..97291cd81 100644 --- a/src/crewai_tools/tools/rag/rag_tool.py +++ b/src/crewai_tools/tools/rag/rag_tool.py @@ -1,38 +1,71 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any -from pydantic.v1 import BaseModel, ConfigDict +from pydantic import BaseModel, Field, model_validator from crewai_tools.tools.base_tool import BaseTool class Adapter(BaseModel, ABC): - model_config = ConfigDict(arbitrary_types_allowed=True) + class Config: + arbitrary_types_allowed = True @abstractmethod def query(self, question: str) -> str: """Query the knowledge base with a question and return the answer.""" + @abstractmethod + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + """Add content to the knowledge base.""" + + class RagTool(BaseTool): - model_config = ConfigDict(arbitrary_types_allowed=True) + class _AdapterPlaceholder(Adapter): + def query(self, question: str) -> str: + raise NotImplementedError + + def add(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + name: str = "Knowledge base" description: str = "A knowledge base that can be used to answer questions." summarize: bool = False - adapter: Optional[Adapter] = None - app: Optional[Any] = None + adapter: Adapter = Field(default_factory=_AdapterPlaceholder) + config: dict[str, Any] | None = None + + @model_validator(mode="after") + def _set_default_adapter(self): + if isinstance(self.adapter, RagTool._AdapterPlaceholder): + from embedchain import App + + from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter + + app = App.from_config(config=self.config) if self.config else App() + self.adapter = EmbedchainAdapter( + embedchain_app=app, summarize=self.summarize + ) + + return self + + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.adapter.add(*args, **kwargs) def _run( self, query: str, + **kwargs: Any, ) -> Any: - from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter - self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize) + self._before_run(query, **kwargs) + return f"Relevant Content:\n{self.adapter.query(query)}" - def from_embedchain(self, config_path: str): - from embedchain import App - from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter - - app = App.from_config(config_path=config_path) - adapter = EmbedchainAdapter(embedchain_app=app) - return RagTool(name=self.name, description=self.description, adapter=adapter) \ No newline at end of file + def _before_run(self, query, **kwargs): + pass diff --git a/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py b/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py index 0a61eae53..375ba960a 100644 --- a/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py +++ b/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py @@ -1,40 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool + class FixedTXTSearchToolSchema(BaseModel): - """Input for TXTSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the txt's content") + """Input for TXTSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the txt's content", + ) + class TXTSearchToolSchema(FixedTXTSearchToolSchema): - """Input for TXTSearchTool.""" - txt: str = Field(..., description="Mandatory txt path you want to search") + """Input for TXTSearchTool.""" + + txt: str = Field(..., description="Mandatory txt path you want to search") + class TXTSearchTool(RagTool): - name: str = "Search a txt's content" - description: str = "A tool that can be used to semantic search a query from a txt's content." - summarize: bool = False - args_schema: Type[BaseModel] = TXTSearchToolSchema - txt: Optional[str] = None + name: str = "Search a txt's content" + description: str = ( + "A tool that can be used to semantic search a query from a txt's content." + ) + args_schema: Type[BaseModel] = TXTSearchToolSchema - def __init__(self, txt: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if txt is not None: - self.txt = txt - self.description = f"A tool that can be used to semantic search a query the {txt} txt's content." - self.args_schema = FixedTXTSearchToolSchema - self._generate_description() + def __init__(self, txt: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if txt is not None: + self.add(txt) + self.description = f"A tool that can be used to semantic search a query the {txt} txt's content." + self.args_schema = FixedTXTSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - txt = kwargs.get('txt', self.txt) - self.app = App() - self.app.add(txt, data_type=DataType.TEXT_FILE) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.TEXT_FILE + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "txt" in kwargs: + self.add(kwargs["txt"]) diff --git a/src/crewai_tools/tools/website_search/website_search_tool.py b/src/crewai_tools/tools/website_search/website_search_tool.py index 37744f2b6..5768a6ccd 100644 --- a/src/crewai_tools/tools/website_search/website_search_tool.py +++ b/src/crewai_tools/tools/website_search/website_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedWebsiteSearchToolSchema(BaseModel): - """Input for WebsiteSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search a specific website") + """Input for WebsiteSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search a specific website", + ) + class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema): - """Input for WebsiteSearchTool.""" - website: str = Field(..., description="Mandatory valid website URL you want to search on") + """Input for WebsiteSearchTool.""" + + website: str = Field( + ..., description="Mandatory valid website URL you want to search on" + ) + class WebsiteSearchTool(RagTool): - name: str = "Search in a specific website" - description: str = "A tool that can be used to semantic search a query from a specific URL content." - summarize: bool = False - args_schema: Type[BaseModel] = WebsiteSearchToolSchema - website: Optional[str] = None + name: str = "Search in a specific website" + description: str = "A tool that can be used to semantic search a query from a specific URL content." + args_schema: Type[BaseModel] = WebsiteSearchToolSchema - def __init__(self, website: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if website is not None: - self.website = website - self.description = f"A tool that can be used to semantic search a query from {website} website content." - self.args_schema = FixedWebsiteSearchToolSchema - self._generate_description() + def __init__(self, website: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if website is not None: + self.add(website) + self.description = f"A tool that can be used to semantic search a query from {website} website content." + self.args_schema = FixedWebsiteSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - website = kwargs.get('website', self.website) - self.app = App() - self.app.add(website, data_type=DataType.WEB_PAGE) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.WEB_PAGE + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "website" in kwargs: + self.add(kwargs["website"]) diff --git a/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py b/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py index 90cedfa56..4b3e445ea 100644 --- a/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py +++ b/src/crewai_tools/tools/xml_search_tool/xml_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedXMLSearchToolSchema(BaseModel): - """Input for XMLSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the XML's content") + """Input for XMLSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the XML's content", + ) + class XMLSearchToolSchema(FixedXMLSearchToolSchema): - """Input for XMLSearchTool.""" - xml: str = Field(..., description="Mandatory xml path you want to search") + """Input for XMLSearchTool.""" + + xml: str = Field(..., description="Mandatory xml path you want to search") + class XMLSearchTool(RagTool): - name: str = "Search a XML's content" - description: str = "A tool that can be used to semantic search a query from a XML's content." - summarize: bool = False - args_schema: Type[BaseModel] = XMLSearchToolSchema - xml: Optional[str] = None + name: str = "Search a XML's content" + description: str = ( + "A tool that can be used to semantic search a query from a XML's content." + ) + args_schema: Type[BaseModel] = XMLSearchToolSchema - def __init__(self, xml: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if xml is not None: - self.xml = xml - self.description = f"A tool that can be used to semantic search a query the {xml} XML's content." - self.args_schema = FixedXMLSearchToolSchema - self._generate_description() + def __init__(self, xml: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if xml is not None: + self.add(xml) + self.description = f"A tool that can be used to semantic search a query the {xml} XML's content." + self.args_schema = FixedXMLSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - xml = kwargs.get('xml', self.xml) - self.app = App() - self.app.add(xml, data_type=DataType.XML) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.XML + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "xml" in kwargs: + self.add(kwargs["xml"]) diff --git a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py index fcdfe78c9..d3e4698c9 100644 --- a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py +++ b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py @@ -1,43 +1,55 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedYoutubeChannelSearchToolSchema(BaseModel): - """Input for YoutubeChannelSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Channels content") + """Input for YoutubeChannelSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the Youtube Channels content", + ) + class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema): - """Input for YoutubeChannelSearchTool.""" - youtube_channel_handle: str = Field(..., description="Mandatory youtube_channel_handle path you want to search") + """Input for YoutubeChannelSearchTool.""" + + youtube_channel_handle: str = Field( + ..., description="Mandatory youtube_channel_handle path you want to search" + ) + class YoutubeChannelSearchTool(RagTool): - name: str = "Search a Youtube Channels content" - description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." - summarize: bool = False - args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema - youtube_channel_handle: Optional[str] = None + name: str = "Search a Youtube Channels content" + description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." + args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema - def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if youtube_channel_handle is not None: - self.youtube_channel_handle = youtube_channel_handle - self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content." - self.args_schema = FixedYoutubeChannelSearchToolSchema - self._generate_description() + def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if youtube_channel_handle is not None: + self.add(youtube_channel_handle) + self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content." + self.args_schema = FixedYoutubeChannelSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - youtube_channel_handle = kwargs.get('youtube_channel_handle', self.youtube_channel_handle) - if not youtube_channel_handle.startswith("@"): - youtube_channel_handle = f"@{youtube_channel_handle}" - self.app = App() - self.app.add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + youtube_channel_handle: str, + **kwargs: Any, + ) -> None: + if not youtube_channel_handle.startswith("@"): + youtube_channel_handle = f"@{youtube_channel_handle}" + + kwargs["data_type"] = DataType.YOUTUBE_CHANNEL + super().add(youtube_channel_handle, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "youtube_channel_handle" in kwargs: + self.add(kwargs["youtube_channel_handle"]) diff --git a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py index 20aa9691d..f85457988 100644 --- a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py +++ b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py @@ -1,41 +1,52 @@ -from typing import Optional, Type, Any -from pydantic.v1 import BaseModel, Field +from typing import Any, Optional, Type -from embedchain import App from embedchain.models.data_type import DataType +from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool class FixedYoutubeVideoSearchToolSchema(BaseModel): - """Input for YoutubeVideoSearchTool.""" - search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Video content") + """Input for YoutubeVideoSearchTool.""" + + search_query: str = Field( + ..., + description="Mandatory search query you want to use to search the Youtube Video content", + ) + class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema): - """Input for YoutubeVideoSearchTool.""" - youtube_video_url: str = Field(..., description="Mandatory youtube_video_url path you want to search") + """Input for YoutubeVideoSearchTool.""" + + youtube_video_url: str = Field( + ..., description="Mandatory youtube_video_url path you want to search" + ) + class YoutubeVideoSearchTool(RagTool): - name: str = "Search a Youtube Video content" - description: str = "A tool that can be used to semantic search a query from a Youtube Video content." - summarize: bool = False - args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema - youtube_video_url: Optional[str] = None + name: str = "Search a Youtube Video content" + description: str = "A tool that can be used to semantic search a query from a Youtube Video content." + args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema - def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if youtube_video_url is not None: - self.youtube_video_url = youtube_video_url - self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content." - self.args_schema = FixedYoutubeVideoSearchToolSchema - self._generate_description() + def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + if youtube_video_url is not None: + self.add(youtube_video_url) + self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content." + self.args_schema = FixedYoutubeVideoSearchToolSchema - def _run( - self, - search_query: str, - **kwargs: Any, - ) -> Any: - youtube_video_url = kwargs.get('youtube_video_url', self.youtube_video_url) - self.app = App() - self.app.add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO) - return super()._run(query=search_query) \ No newline at end of file + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + kwargs["data_type"] = DataType.YOUTUBE_VIDEO + super().add(*args, **kwargs) + + def _before_run( + self, + query: str, + **kwargs: Any, + ) -> Any: + if "youtube_video_url" in kwargs: + self.add(kwargs["youtube_video_url"]) diff --git a/tests/tools/rag/rag_tool_test.py b/tests/tools/rag/rag_tool_test.py new file mode 100644 index 000000000..42baccc2c --- /dev/null +++ b/tests/tools/rag/rag_tool_test.py @@ -0,0 +1,43 @@ +import os +from tempfile import NamedTemporaryFile +from typing import cast +from unittest import mock + +from pytest import fixture + +from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter +from crewai_tools.tools.rag.rag_tool import RagTool + + +@fixture(autouse=True) +def mock_embedchain_db_uri(): + with NamedTemporaryFile() as tmp: + uri = f"sqlite:///{tmp.name}" + with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}): + yield + + +def test_custom_llm_and_embedder(): + class MyTool(RagTool): + pass + + tool = MyTool( + config=dict( + llm=dict( + provider="openai", + config=dict(model="gpt-3.5-custom"), + ), + embedder=dict( + provider="openai", + config=dict(model="text-embedding-3-custom"), + ), + ) + ) + assert tool.adapter is not None + assert isinstance(tool.adapter, EmbedchainAdapter) + + adapter = cast(EmbedchainAdapter, tool.adapter) + assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom" + assert ( + adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom" + )