Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,12 +1,25 @@
from typing import Any
from embedchain import App
from crewai_tools.tools.rag.rag_tool import Adapter
class EmbedchainAdapter(Adapter):
embedchain_app: Any
embedchain_app: App
summarize: bool = False
def query(self, question: str) -> str:
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize))
result, sources = self.embedchain_app.query(
question, citations=True, dry_run=(not self.summarize)
)
if self.summarize:
return result
return "\n\n".join([source[0] for source in sources])
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.embedchain_app.add(*args, **kwargs)

View File

@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name)
return super().model_post_init(__context)
super().model_post_init(__context)
def query(self, question: str) -> str:
query = self.embedding_function([question])[0]
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
)
values = [result[self.text_column_name] for result in results]
return "\n".join(values)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self._table.add(*args, **kwargs)

View File

@@ -1,28 +1,47 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Type
from pydantic import BaseModel, model_validator
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.tools import StructuredTool
class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(V1BaseModel):
pass
model_config = ConfigDict()
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
args_schema: Optional[Type[V1BaseModel]] = None
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
"""The schema for the arguments that the tool accepts."""
description_updated: bool = False
"""Flag to check if the description has been updated."""
cache_function: Optional[Callable] = lambda: True
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
@model_validator(mode="after")
def _check_args_schema(self):
self._set_args_schema()
@validator("args_schema", always=True, pre=True)
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
if not isinstance(v, cls._ArgsSchemaPlaceholder):
return v
return type(
f"{cls.__name__}Schema",
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in cls._run.__annotations__.items() if k != "return"
},
},
)
def model_post_init(self, __context: Any) -> None:
self._generate_description()
return self
super().model_post_init(__context)
def run(
self,
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in self._run.__annotations__.items() if k != 'return'
k: v
for k, v in self._run.__annotations__.items()
if k != "return"
},
},
)
def _generate_description(self):
args = []
for arg, attribute in self.args_schema.schema()['properties'].items():
args.append(f"{arg}: '{attribute['type']}'")
for arg, attribute in self.args_schema.schema()["properties"].items():
if "type" in attribute:
args.append(f"{arg}: '{attribute['type']}'")
description = self.description.replace('\n', ' ')
description = self.description.replace("\n", " ")
self.description = f"{self.name}({', '.join(args)}) - {description}"
@@ -93,19 +116,19 @@ def tool(*args):
def _make_tool(f: Callable) -> BaseTool:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
raise ValueError("Function must have type annotations")
args_schema = None
if f.__annotations__:
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != 'return'
},
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
},
)
},
)
return Tool(
name=tool_name,
@@ -120,4 +143,4 @@ def tool(*args):
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
raise ValueError("Invalid arguments")
raise ValueError("Invalid arguments")

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedCodeDocsSearchToolSchema(BaseModel):
"""Input for CodeDocsSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Code Docs content")
"""Input for CodeDocsSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Code Docs content",
)
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
"""Input for CodeDocsSearchTool."""
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
"""Input for CodeDocsSearchTool."""
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
class CodeDocsSearchTool(RagTool):
name: str = "Search a Code Docs content"
description: str = "A tool that can be used to semantic search a query from a Code Docs content."
summarize: bool = False
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
docs_url: Optional[str] = None
name: str = "Search a Code Docs content"
description: str = (
"A tool that can be used to semantic search a query from a Code Docs content."
)
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docs_url is not None:
self.docs_url = docs_url
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description()
def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docs_url is not None:
self.add(docs_url)
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
docs_url = kwargs.get('docs_url', self.docs_url)
self.app = App()
self.app.add(docs_url, data_type=DataType.DOCS_SITE)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.DOCS_SITE
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docs_url" in kwargs:
self.add(kwargs["docs_url"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedCSVSearchToolSchema(BaseModel):
"""Input for CSVSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the CSV's content")
"""Input for CSVSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the CSV's content",
)
class CSVSearchToolSchema(FixedCSVSearchToolSchema):
"""Input for CSVSearchTool."""
csv: str = Field(..., description="Mandatory csv path you want to search")
"""Input for CSVSearchTool."""
csv: str = Field(..., description="Mandatory csv path you want to search")
class CSVSearchTool(RagTool):
name: str = "Search a CSV's content"
description: str = "A tool that can be used to semantic search a query from a CSV's content."
summarize: bool = False
args_schema: Type[BaseModel] = CSVSearchToolSchema
csv: Optional[str] = None
name: str = "Search a CSV's content"
description: str = (
"A tool that can be used to semantic search a query from a CSV's content."
)
args_schema: Type[BaseModel] = CSVSearchToolSchema
def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if csv is not None:
self.csv = csv
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema
self._generate_description()
def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if csv is not None:
self.add(csv)
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
csv = kwargs.get('csv', self.csv)
self.app = App()
self.app.add(csv, data_type=DataType.CSV)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.CSV
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "csv" in kwargs:
self.add(kwargs["csv"])

View File

@@ -1,42 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.loaders.directory_loader import DirectoryLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedDirectorySearchToolSchema(BaseModel):
"""Input for DirectorySearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the directory's content")
"""Input for DirectorySearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the directory's content",
)
class DirectorySearchToolSchema(FixedDirectorySearchToolSchema):
"""Input for DirectorySearchTool."""
directory: str = Field(..., description="Mandatory directory you want to search")
"""Input for DirectorySearchTool."""
directory: str = Field(..., description="Mandatory directory you want to search")
class DirectorySearchTool(RagTool):
name: str = "Search a directory's content"
description: str = "A tool that can be used to semantic search a query from a directory's content."
summarize: bool = False
args_schema: Type[BaseModel] = DirectorySearchToolSchema
directory: Optional[str] = None
name: str = "Search a directory's content"
description: str = (
"A tool that can be used to semantic search a query from a directory's content."
)
args_schema: Type[BaseModel] = DirectorySearchToolSchema
def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.directory = directory
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()
def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if directory is not None:
self.add(directory)
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
directory = kwargs.get('directory', self.directory)
loader = DirectoryLoader(config=dict(recursive=True))
self.app = App()
self.app.add(directory, loader=loader)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "directory" in kwargs:
self.add(kwargs["directory"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedDOCXSearchToolSchema(BaseModel):
"""Input for DOCXSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the DOCX's content")
"""Input for DOCXSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the DOCX's content",
)
class DOCXSearchToolSchema(FixedDOCXSearchToolSchema):
"""Input for DOCXSearchTool."""
docx: str = Field(..., description="Mandatory docx path you want to search")
"""Input for DOCXSearchTool."""
docx: str = Field(..., description="Mandatory docx path you want to search")
class DOCXSearchTool(RagTool):
name: str = "Search a DOCX's content"
description: str = "A tool that can be used to semantic search a query from a DOCX's content."
summarize: bool = False
args_schema: Type[BaseModel] = DOCXSearchToolSchema
docx: Optional[str] = None
name: str = "Search a DOCX's content"
description: str = (
"A tool that can be used to semantic search a query from a DOCX's content."
)
args_schema: Type[BaseModel] = DOCXSearchToolSchema
def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
self.docx = docx
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()
def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if docx is not None:
self.add(docx)
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
docx = kwargs.get('docx', self.docx)
self.app = App()
self.app.add(docx, data_type=DataType.DOCX)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.DOCX
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docx" in kwargs:
self.add(kwargs["docx"])

View File

@@ -1,46 +1,58 @@
from typing import Optional, Type, List, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, List, Optional, Type
from embedchain import App
from embedchain.loaders.github import GithubLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedGithubSearchToolSchema(BaseModel):
"""Input for GithubSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the github repo's content")
"""Input for GithubSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the github repo's content",
)
class GithubSearchToolSchema(FixedGithubSearchToolSchema):
"""Input for GithubSearchTool."""
github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field(..., description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]")
"""Input for GithubSearchTool."""
github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field(
...,
description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]",
)
class GithubSearchTool(RagTool):
name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content."
summarize: bool = False
gh_token: str = None
args_schema: Type[BaseModel] = GithubSearchToolSchema
github_repo: Optional[str] = None
content_types: List[str]
name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content."
summarize: bool = False
gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema
content_types: List[str]
def __init__(self, github_repo: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if github_repo is not None:
self.github_repo = github_repo
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
self.args_schema = FixedGithubSearchToolSchema
self._generate_description()
def __init__(self, github_repo: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if github_repo is not None:
self.add(github_repo)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
self.args_schema = FixedGithubSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
github_repo = kwargs.get('github_repo', self.github_repo)
loader = GithubLoader(config={"token": self.gh_token})
app = App()
app.add(f"repo:{github_repo} type:{','.join(self.content_types)}", data_type="github", loader=loader)
self.app = app
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = "github"
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "github_repo" in kwargs:
self.add(kwargs["github_repo"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedJSONSearchToolSchema(BaseModel):
"""Input for JSONSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the JSON's content")
"""Input for JSONSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the JSON's content",
)
class JSONSearchToolSchema(FixedJSONSearchToolSchema):
"""Input for JSONSearchTool."""
json_path: str = Field(..., description="Mandatory json path you want to search")
"""Input for JSONSearchTool."""
json_path: str = Field(..., description="Mandatory json path you want to search")
class JSONSearchTool(RagTool):
name: str = "Search a JSON's content"
description: str = "A tool that can be used to semantic search a query from a JSON's content."
summarize: bool = False
args_schema: Type[BaseModel] = JSONSearchToolSchema
json_path: Optional[str] = None
name: str = "Search a JSON's content"
description: str = (
"A tool that can be used to semantic search a query from a JSON's content."
)
args_schema: Type[BaseModel] = JSONSearchToolSchema
def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if json_path is not None:
self.json_path = json_path
self.description = f"A tool that can be used to semantic search a query the {json} JSON's content."
self.args_schema = FixedJSONSearchToolSchema
self._generate_description()
def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if json_path is not None:
self.add(json_path)
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
self.args_schema = FixedJSONSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
json_path = kwargs.get('json_path', self.json_path)
self.app = App()
self.app.add(json_path, data_type=DataType.JSON)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.JSON
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "json_path" in kwargs:
self.add(kwargs["json_path"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedMDXSearchToolSchema(BaseModel):
"""Input for MDXSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the MDX's content")
"""Input for MDXSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the MDX's content",
)
class MDXSearchToolSchema(FixedMDXSearchToolSchema):
"""Input for MDXSearchTool."""
mdx: str = Field(..., description="Mandatory mdx path you want to search")
"""Input for MDXSearchTool."""
mdx: str = Field(..., description="Mandatory mdx path you want to search")
class MDXSearchTool(RagTool):
name: str = "Search a MDX's content"
description: str = "A tool that can be used to semantic search a query from a MDX's content."
summarize: bool = False
args_schema: Type[BaseModel] = MDXSearchToolSchema
mdx: Optional[str] = None
name: str = "Search a MDX's content"
description: str = (
"A tool that can be used to semantic search a query from a MDX's content."
)
args_schema: Type[BaseModel] = MDXSearchToolSchema
def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if mdx is not None:
self.mdx = mdx
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema
self._generate_description()
def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if mdx is not None:
self.add(mdx)
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
mdx = kwargs.get('mdx', self.mdx)
self.app = App()
self.app.add(mdx, data_type=DataType.MDX)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.MDX
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "mdx" in kwargs:
self.add(kwargs["mdx"])

View File

@@ -1,41 +1,51 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedPDFSearchToolSchema(BaseModel):
"""Input for PDFSearchTool."""
query: str = Field(..., description="Mandatory query you want to use to search the PDF's content")
"""Input for PDFSearchTool."""
query: str = Field(
..., description="Mandatory query you want to use to search the PDF's content"
)
class PDFSearchToolSchema(FixedPDFSearchToolSchema):
"""Input for PDFSearchTool."""
pdf: str = Field(..., description="Mandatory pdf path you want to search")
"""Input for PDFSearchTool."""
pdf: str = Field(..., description="Mandatory pdf path you want to search")
class PDFSearchTool(RagTool):
name: str = "Search a PDF's content"
description: str = "A tool that can be used to semantic search a query from a PDF's content."
summarize: bool = False
args_schema: Type[BaseModel] = PDFSearchToolSchema
pdf: Optional[str] = None
name: str = "Search a PDF's content"
description: str = (
"A tool that can be used to semantic search a query from a PDF's content."
)
args_schema: Type[BaseModel] = PDFSearchToolSchema
def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
self.pdf = pdf
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema
self._generate_description()
def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
pdf = kwargs.get('pdf', self.pdf)
self.app = App()
self.app.add(pdf, data_type=DataType.PDF_FILE)
return super()._run(query=query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.PDF_FILE
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "pdf" in kwargs:
self.add(kwargs["pdf"])

View File

@@ -1,45 +1,37 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Type
from embedchain import App
from embedchain.loaders.postgres import PostgresLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class PGSearchToolSchema(BaseModel):
"""Input for PGSearchTool."""
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
"""Input for PGSearchTool."""
search_query: str = Field(
...,
description="Mandatory semantic search query you want to use to search the database's content",
)
class PGSearchTool(RagTool):
name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content."
summarize: bool = False
args_schema: Type[BaseModel] = PGSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
table_name: str = Field(..., description="Mandatory table name")
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content."
args_schema: Type[BaseModel] = PGSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
def __init__(self, table_name: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if table_name is not None:
self.table_name = table_name
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()
else:
raise('To use PGSearchTool, you must provide a `table_name` argument')
def __init__(self, table_name: str, **kwargs):
super().__init__(**kwargs)
self.add(table_name)
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
config = { "url": self.db_uri }
postgres_loader = PostgresLoader(config=config)
app = App()
app.add(
f"SELECT * FROM {self.table_name};",
data_type='postgres',
loader=postgres_loader
)
return super()._run(query=search_query)
def add(
self,
table_name: str,
**kwargs: Any,
) -> None:
kwargs["data_type"] = "postgres"
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
super().add(f"SELECT * FROM {table_name};", **kwargs)

View File

@@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory')
# Example: Loading from a web page
rag_tool = RagTool().from_web_page('https://example.com')
# Example: Loading from an Embedchain configuration
rag_tool = RagTool().from_embedchain('path/to/your/config.json')
```
## **Contribution**
@@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To
RagTool is open-source and available under the MIT license.
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.

View File

@@ -1,38 +1,71 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any
from pydantic.v1 import BaseModel, ConfigDict
from pydantic import BaseModel, Field, model_validator
from crewai_tools.tools.base_tool import BaseTool
class Adapter(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
@abstractmethod
def query(self, question: str) -> str:
"""Query the knowledge base with a question and return the answer."""
@abstractmethod
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Add content to the knowledge base."""
class RagTool(BaseTool):
model_config = ConfigDict(arbitrary_types_allowed=True)
class _AdapterPlaceholder(Adapter):
def query(self, question: str) -> str:
raise NotImplementedError
def add(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
name: str = "Knowledge base"
description: str = "A knowledge base that can be used to answer questions."
summarize: bool = False
adapter: Optional[Adapter] = None
app: Optional[Any] = None
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: dict[str, Any] | None = None
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config=self.config) if self.config else App()
self.adapter = EmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.adapter.add(*args, **kwargs)
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
self._before_run(query, **kwargs)
return f"Relevant Content:\n{self.adapter.query(query)}"
def from_embedchain(self, config_path: str):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config_path=config_path)
adapter = EmbedchainAdapter(embedchain_app=app)
return RagTool(name=self.name, description=self.description, adapter=adapter)
def _before_run(self, query, **kwargs):
pass

View File

@@ -1,40 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedTXTSearchToolSchema(BaseModel):
"""Input for TXTSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the txt's content")
"""Input for TXTSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the txt's content",
)
class TXTSearchToolSchema(FixedTXTSearchToolSchema):
"""Input for TXTSearchTool."""
txt: str = Field(..., description="Mandatory txt path you want to search")
"""Input for TXTSearchTool."""
txt: str = Field(..., description="Mandatory txt path you want to search")
class TXTSearchTool(RagTool):
name: str = "Search a txt's content"
description: str = "A tool that can be used to semantic search a query from a txt's content."
summarize: bool = False
args_schema: Type[BaseModel] = TXTSearchToolSchema
txt: Optional[str] = None
name: str = "Search a txt's content"
description: str = (
"A tool that can be used to semantic search a query from a txt's content."
)
args_schema: Type[BaseModel] = TXTSearchToolSchema
def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
self.txt = txt
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema
self._generate_description()
def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
txt = kwargs.get('txt', self.txt)
self.app = App()
self.app.add(txt, data_type=DataType.TEXT_FILE)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.TEXT_FILE
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "txt" in kwargs:
self.add(kwargs["txt"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedWebsiteSearchToolSchema(BaseModel):
"""Input for WebsiteSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search a specific website")
"""Input for WebsiteSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search a specific website",
)
class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
"""Input for WebsiteSearchTool."""
website: str = Field(..., description="Mandatory valid website URL you want to search on")
"""Input for WebsiteSearchTool."""
website: str = Field(
..., description="Mandatory valid website URL you want to search on"
)
class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content."
summarize: bool = False
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
website: Optional[str] = None
name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content."
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if website is not None:
self.website = website
self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema
self._generate_description()
def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if website is not None:
self.add(website)
self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
website = kwargs.get('website', self.website)
self.app = App()
self.app.add(website, data_type=DataType.WEB_PAGE)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.WEB_PAGE
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "website" in kwargs:
self.add(kwargs["website"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedXMLSearchToolSchema(BaseModel):
"""Input for XMLSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the XML's content")
"""Input for XMLSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the XML's content",
)
class XMLSearchToolSchema(FixedXMLSearchToolSchema):
"""Input for XMLSearchTool."""
xml: str = Field(..., description="Mandatory xml path you want to search")
"""Input for XMLSearchTool."""
xml: str = Field(..., description="Mandatory xml path you want to search")
class XMLSearchTool(RagTool):
name: str = "Search a XML's content"
description: str = "A tool that can be used to semantic search a query from a XML's content."
summarize: bool = False
args_schema: Type[BaseModel] = XMLSearchToolSchema
xml: Optional[str] = None
name: str = "Search a XML's content"
description: str = (
"A tool that can be used to semantic search a query from a XML's content."
)
args_schema: Type[BaseModel] = XMLSearchToolSchema
def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if xml is not None:
self.xml = xml
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema
self._generate_description()
def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if xml is not None:
self.add(xml)
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
xml = kwargs.get('xml', self.xml)
self.app = App()
self.app.add(xml, data_type=DataType.XML)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.XML
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "xml" in kwargs:
self.add(kwargs["xml"])

View File

@@ -1,43 +1,55 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedYoutubeChannelSearchToolSchema(BaseModel):
"""Input for YoutubeChannelSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Channels content")
"""Input for YoutubeChannelSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Youtube Channels content",
)
class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
"""Input for YoutubeChannelSearchTool."""
youtube_channel_handle: str = Field(..., description="Mandatory youtube_channel_handle path you want to search")
"""Input for YoutubeChannelSearchTool."""
youtube_channel_handle: str = Field(
..., description="Mandatory youtube_channel_handle path you want to search"
)
class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
summarize: bool = False
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
youtube_channel_handle: Optional[str] = None
name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_channel_handle is not None:
self.youtube_channel_handle = youtube_channel_handle
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema
self._generate_description()
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_channel_handle is not None:
self.add(youtube_channel_handle)
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
youtube_channel_handle = kwargs.get('youtube_channel_handle', self.youtube_channel_handle)
if not youtube_channel_handle.startswith("@"):
youtube_channel_handle = f"@{youtube_channel_handle}"
self.app = App()
self.app.add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
return super()._run(query=search_query)
def add(
self,
youtube_channel_handle: str,
**kwargs: Any,
) -> None:
if not youtube_channel_handle.startswith("@"):
youtube_channel_handle = f"@{youtube_channel_handle}"
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
super().add(youtube_channel_handle, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Type
from embedchain import App
from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class FixedYoutubeVideoSearchToolSchema(BaseModel):
"""Input for YoutubeVideoSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Video content")
"""Input for YoutubeVideoSearchTool."""
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Youtube Video content",
)
class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
"""Input for YoutubeVideoSearchTool."""
youtube_video_url: str = Field(..., description="Mandatory youtube_video_url path you want to search")
"""Input for YoutubeVideoSearchTool."""
youtube_video_url: str = Field(
..., description="Mandatory youtube_video_url path you want to search"
)
class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
summarize: bool = False
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
youtube_video_url: Optional[str] = None
name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_video_url is not None:
self.youtube_video_url = youtube_video_url
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema
self._generate_description()
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if youtube_video_url is not None:
self.add(youtube_video_url)
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
youtube_video_url = kwargs.get('youtube_video_url', self.youtube_video_url)
self.app = App()
self.app.add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
return super()._run(query=search_query)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
super().add(*args, **kwargs)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"])

View File

@@ -0,0 +1,43 @@
import os
from tempfile import NamedTemporaryFile
from typing import cast
from unittest import mock
from pytest import fixture
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@fixture(autouse=True)
def mock_embedchain_db_uri():
with NamedTemporaryFile() as tmp:
uri = f"sqlite:///{tmp.name}"
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
yield
def test_custom_llm_and_embedder():
class MyTool(RagTool):
pass
tool = MyTool(
config=dict(
llm=dict(
provider="openai",
config=dict(model="gpt-3.5-custom"),
),
embedder=dict(
provider="openai",
config=dict(model="text-embedding-3-custom"),
),
)
)
assert tool.adapter is not None
assert isinstance(tool.adapter, EmbedchainAdapter)
adapter = cast(EmbedchainAdapter, tool.adapter)
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
assert (
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
)