mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Custom model config for RAG tools
This commit is contained in:
@@ -1,12 +1,25 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
|
||||
embedchain_app: Any
|
||||
embedchain_app: App
|
||||
summarize: bool = False
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize))
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize)
|
||||
)
|
||||
if self.summarize:
|
||||
return result
|
||||
return "\n\n".join([source[0] for source in sources])
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.embedchain_app.add(*args, **kwargs)
|
||||
|
||||
@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
|
||||
return super().model_post_init(__context)
|
||||
super().model_post_init(__context)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
query = self.embedding_function([question])[0]
|
||||
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
|
||||
)
|
||||
values = [result[self.text_column_name] for result in results]
|
||||
return "\n".join(values)
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._table.add(*args, **kwargs)
|
||||
|
||||
@@ -1,28 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(V1BaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Optional[Type[V1BaseModel]] = None
|
||||
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
cache_function: Optional[Callable] = lambda: True
|
||||
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_args_schema(self):
|
||||
self._set_args_schema()
|
||||
@validator("args_schema", always=True, pre=True)
|
||||
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
|
||||
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||
return v
|
||||
|
||||
return type(
|
||||
f"{cls.__name__}Schema",
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in self._run.__annotations__.items() if k != 'return'
|
||||
k: v
|
||||
for k, v in self._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _generate_description(self):
|
||||
args = []
|
||||
for arg, attribute in self.args_schema.schema()['properties'].items():
|
||||
args.append(f"{arg}: '{attribute['type']}'")
|
||||
for arg, attribute in self.args_schema.schema()["properties"].items():
|
||||
if "type" in attribute:
|
||||
args.append(f"{arg}: '{attribute['type']}'")
|
||||
|
||||
description = self.description.replace('\n', ' ')
|
||||
description = self.description.replace("\n", " ")
|
||||
self.description = f"{self.name}({', '.join(args)}) - {description}"
|
||||
|
||||
|
||||
@@ -93,19 +116,19 @@ def tool(*args):
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
args_schema = None
|
||||
if f.__annotations__:
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != 'return'
|
||||
},
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
@@ -120,4 +143,4 @@ def tool(*args):
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
raise ValueError("Invalid arguments")
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||
"""Input for CodeDocsSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Code Docs content")
|
||||
"""Input for CodeDocsSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the Code Docs content",
|
||||
)
|
||||
|
||||
|
||||
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
|
||||
"""Input for CodeDocsSearchTool."""
|
||||
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
|
||||
"""Input for CodeDocsSearchTool."""
|
||||
|
||||
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
|
||||
|
||||
|
||||
class CodeDocsSearchTool(RagTool):
|
||||
name: str = "Search a Code Docs content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Code Docs content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||
docs_url: Optional[str] = None
|
||||
name: str = "Search a Code Docs content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a Code Docs content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
self.docs_url = docs_url
|
||||
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
docs_url = kwargs.get('docs_url', self.docs_url)
|
||||
self.app = App()
|
||||
self.app.add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.DOCS_SITE
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docs_url" in kwargs:
|
||||
self.add(kwargs["docs_url"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedCSVSearchToolSchema(BaseModel):
|
||||
"""Input for CSVSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the CSV's content")
|
||||
"""Input for CSVSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the CSV's content",
|
||||
)
|
||||
|
||||
|
||||
class CSVSearchToolSchema(FixedCSVSearchToolSchema):
|
||||
"""Input for CSVSearchTool."""
|
||||
csv: str = Field(..., description="Mandatory csv path you want to search")
|
||||
"""Input for CSVSearchTool."""
|
||||
|
||||
csv: str = Field(..., description="Mandatory csv path you want to search")
|
||||
|
||||
|
||||
class CSVSearchTool(RagTool):
|
||||
name: str = "Search a CSV's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a CSV's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = CSVSearchToolSchema
|
||||
csv: Optional[str] = None
|
||||
name: str = "Search a CSV's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a CSV's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CSVSearchToolSchema
|
||||
|
||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if csv is not None:
|
||||
self.csv = csv
|
||||
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
||||
self.args_schema = FixedCSVSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
||||
self.args_schema = FixedCSVSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
csv = kwargs.get('csv', self.csv)
|
||||
self.app = App()
|
||||
self.app.add(csv, data_type=DataType.CSV)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.CSV
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "csv" in kwargs:
|
||||
self.add(kwargs["csv"])
|
||||
|
||||
@@ -1,42 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
"""Input for DirectorySearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the directory's content")
|
||||
"""Input for DirectorySearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the directory's content",
|
||||
)
|
||||
|
||||
|
||||
class DirectorySearchToolSchema(FixedDirectorySearchToolSchema):
|
||||
"""Input for DirectorySearchTool."""
|
||||
directory: str = Field(..., description="Mandatory directory you want to search")
|
||||
"""Input for DirectorySearchTool."""
|
||||
|
||||
directory: str = Field(..., description="Mandatory directory you want to search")
|
||||
|
||||
|
||||
class DirectorySearchTool(RagTool):
|
||||
name: str = "Search a directory's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a directory's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||
directory: Optional[str] = None
|
||||
name: str = "Search a directory's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a directory's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
self.directory = directory
|
||||
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
||||
self.args_schema = FixedDirectorySearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
||||
self.args_schema = FixedDirectorySearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
directory = kwargs.get('directory', self.directory)
|
||||
loader = DirectoryLoader(config=dict(recursive=True))
|
||||
self.app = App()
|
||||
self.app.add(directory, loader=loader)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "directory" in kwargs:
|
||||
self.add(kwargs["directory"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedDOCXSearchToolSchema(BaseModel):
|
||||
"""Input for DOCXSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the DOCX's content")
|
||||
"""Input for DOCXSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the DOCX's content",
|
||||
)
|
||||
|
||||
|
||||
class DOCXSearchToolSchema(FixedDOCXSearchToolSchema):
|
||||
"""Input for DOCXSearchTool."""
|
||||
docx: str = Field(..., description="Mandatory docx path you want to search")
|
||||
"""Input for DOCXSearchTool."""
|
||||
|
||||
docx: str = Field(..., description="Mandatory docx path you want to search")
|
||||
|
||||
|
||||
class DOCXSearchTool(RagTool):
|
||||
name: str = "Search a DOCX's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a DOCX's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = DOCXSearchToolSchema
|
||||
docx: Optional[str] = None
|
||||
name: str = "Search a DOCX's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a DOCX's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DOCXSearchToolSchema
|
||||
|
||||
def __init__(self, docx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docx is not None:
|
||||
self.docx = docx
|
||||
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
||||
self.args_schema = FixedDOCXSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, docx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if docx is not None:
|
||||
self.add(docx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
||||
self.args_schema = FixedDOCXSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
docx = kwargs.get('docx', self.docx)
|
||||
self.app = App()
|
||||
self.app.add(docx, data_type=DataType.DOCX)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.DOCX
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "docx" in kwargs:
|
||||
self.add(kwargs["docx"])
|
||||
|
||||
@@ -1,46 +1,58 @@
|
||||
from typing import Optional, Type, List, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedGithubSearchToolSchema(BaseModel):
|
||||
"""Input for GithubSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the github repo's content")
|
||||
"""Input for GithubSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the github repo's content",
|
||||
)
|
||||
|
||||
|
||||
class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||
"""Input for GithubSearchTool."""
|
||||
github_repo: str = Field(..., description="Mandatory github you want to search")
|
||||
content_types: List[str] = Field(..., description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]")
|
||||
"""Input for GithubSearchTool."""
|
||||
|
||||
github_repo: str = Field(..., description="Mandatory github you want to search")
|
||||
content_types: List[str] = Field(
|
||||
...,
|
||||
description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
|
||||
|
||||
class GithubSearchTool(RagTool):
|
||||
name: str = "Search a github repo's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a github repo's content."
|
||||
summarize: bool = False
|
||||
gh_token: str = None
|
||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||
github_repo: Optional[str] = None
|
||||
content_types: List[str]
|
||||
name: str = "Search a github repo's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a github repo's content."
|
||||
summarize: bool = False
|
||||
gh_token: str
|
||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||
content_types: List[str]
|
||||
|
||||
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if github_repo is not None:
|
||||
self.github_repo = github_repo
|
||||
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
|
||||
self.args_schema = FixedGithubSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if github_repo is not None:
|
||||
self.add(github_repo)
|
||||
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
|
||||
self.args_schema = FixedGithubSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
github_repo = kwargs.get('github_repo', self.github_repo)
|
||||
loader = GithubLoader(config={"token": self.gh_token})
|
||||
app = App()
|
||||
app.add(f"repo:{github_repo} type:{','.join(self.content_types)}", data_type="github", loader=loader)
|
||||
self.app = app
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = "github"
|
||||
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "github_repo" in kwargs:
|
||||
self.add(kwargs["github_repo"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedJSONSearchToolSchema(BaseModel):
|
||||
"""Input for JSONSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the JSON's content")
|
||||
"""Input for JSONSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the JSON's content",
|
||||
)
|
||||
|
||||
|
||||
class JSONSearchToolSchema(FixedJSONSearchToolSchema):
|
||||
"""Input for JSONSearchTool."""
|
||||
json_path: str = Field(..., description="Mandatory json path you want to search")
|
||||
"""Input for JSONSearchTool."""
|
||||
|
||||
json_path: str = Field(..., description="Mandatory json path you want to search")
|
||||
|
||||
|
||||
class JSONSearchTool(RagTool):
|
||||
name: str = "Search a JSON's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a JSON's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = JSONSearchToolSchema
|
||||
json_path: Optional[str] = None
|
||||
name: str = "Search a JSON's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a JSON's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = JSONSearchToolSchema
|
||||
|
||||
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if json_path is not None:
|
||||
self.json_path = json_path
|
||||
self.description = f"A tool that can be used to semantic search a query the {json} JSON's content."
|
||||
self.args_schema = FixedJSONSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if json_path is not None:
|
||||
self.add(json_path)
|
||||
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
|
||||
self.args_schema = FixedJSONSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
json_path = kwargs.get('json_path', self.json_path)
|
||||
self.app = App()
|
||||
self.app.add(json_path, data_type=DataType.JSON)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.JSON
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "json_path" in kwargs:
|
||||
self.add(kwargs["json_path"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedMDXSearchToolSchema(BaseModel):
|
||||
"""Input for MDXSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the MDX's content")
|
||||
"""Input for MDXSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the MDX's content",
|
||||
)
|
||||
|
||||
|
||||
class MDXSearchToolSchema(FixedMDXSearchToolSchema):
|
||||
"""Input for MDXSearchTool."""
|
||||
mdx: str = Field(..., description="Mandatory mdx path you want to search")
|
||||
"""Input for MDXSearchTool."""
|
||||
|
||||
mdx: str = Field(..., description="Mandatory mdx path you want to search")
|
||||
|
||||
|
||||
class MDXSearchTool(RagTool):
|
||||
name: str = "Search a MDX's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a MDX's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = MDXSearchToolSchema
|
||||
mdx: Optional[str] = None
|
||||
name: str = "Search a MDX's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a MDX's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = MDXSearchToolSchema
|
||||
|
||||
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if mdx is not None:
|
||||
self.mdx = mdx
|
||||
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
||||
self.args_schema = FixedMDXSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if mdx is not None:
|
||||
self.add(mdx)
|
||||
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
||||
self.args_schema = FixedMDXSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
mdx = kwargs.get('mdx', self.mdx)
|
||||
self.app = App()
|
||||
self.app.add(mdx, data_type=DataType.MDX)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.MDX
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "mdx" in kwargs:
|
||||
self.add(kwargs["mdx"])
|
||||
|
||||
@@ -1,41 +1,51 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedPDFSearchToolSchema(BaseModel):
|
||||
"""Input for PDFSearchTool."""
|
||||
query: str = Field(..., description="Mandatory query you want to use to search the PDF's content")
|
||||
"""Input for PDFSearchTool."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Mandatory query you want to use to search the PDF's content"
|
||||
)
|
||||
|
||||
|
||||
class PDFSearchToolSchema(FixedPDFSearchToolSchema):
|
||||
"""Input for PDFSearchTool."""
|
||||
pdf: str = Field(..., description="Mandatory pdf path you want to search")
|
||||
"""Input for PDFSearchTool."""
|
||||
|
||||
pdf: str = Field(..., description="Mandatory pdf path you want to search")
|
||||
|
||||
|
||||
class PDFSearchTool(RagTool):
|
||||
name: str = "Search a PDF's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a PDF's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = PDFSearchToolSchema
|
||||
pdf: Optional[str] = None
|
||||
name: str = "Search a PDF's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a PDF's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = PDFSearchToolSchema
|
||||
|
||||
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
self.pdf = pdf
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
pdf = kwargs.get('pdf', self.pdf)
|
||||
self.app = App()
|
||||
self.app.add(pdf, data_type=DataType.PDF_FILE)
|
||||
return super()._run(query=query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.PDF_FILE
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "pdf" in kwargs:
|
||||
self.add(kwargs["pdf"])
|
||||
|
||||
@@ -1,45 +1,37 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class PGSearchToolSchema(BaseModel):
|
||||
"""Input for PGSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
|
||||
"""Input for PGSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory semantic search query you want to use to search the database's content",
|
||||
)
|
||||
|
||||
|
||||
class PGSearchTool(RagTool):
|
||||
name: str = "Search a database's table content"
|
||||
description: str = "A tool that can be used to semantic search a query from a database table's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = PGSearchToolSchema
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
table_name: str = Field(..., description="Mandatory table name")
|
||||
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
|
||||
name: str = "Search a database's table content"
|
||||
description: str = "A tool that can be used to semantic search a query from a database table's content."
|
||||
args_schema: Type[BaseModel] = PGSearchToolSchema
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if table_name is not None:
|
||||
self.table_name = table_name
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
else:
|
||||
raise('To use PGSearchTool, you must provide a `table_name` argument')
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.add(table_name)
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
|
||||
config = { "url": self.db_uri }
|
||||
postgres_loader = PostgresLoader(config=config)
|
||||
app = App()
|
||||
app.add(
|
||||
f"SELECT * FROM {self.table_name};",
|
||||
data_type='postgres',
|
||||
loader=postgres_loader
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
table_name: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = "postgres"
|
||||
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
|
||||
super().add(f"SELECT * FROM {table_name};", **kwargs)
|
||||
|
||||
@@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory')
|
||||
|
||||
# Example: Loading from a web page
|
||||
rag_tool = RagTool().from_web_page('https://example.com')
|
||||
|
||||
# Example: Loading from an Embedchain configuration
|
||||
rag_tool = RagTool().from_embedchain('path/to/your/config.json')
|
||||
```
|
||||
|
||||
## **Contribution**
|
||||
@@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To
|
||||
|
||||
RagTool is open-source and available under the MIT license.
|
||||
|
||||
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
|
||||
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
|
||||
|
||||
@@ -1,38 +1,71 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic.v1 import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from crewai_tools.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def query(self, question: str) -> str:
|
||||
"""Query the knowledge base with a question and return the answer."""
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
|
||||
class RagTool(BaseTool):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class _AdapterPlaceholder(Adapter):
|
||||
def query(self, question: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base that can be used to answer questions."
|
||||
summarize: bool = False
|
||||
adapter: Optional[Adapter] = None
|
||||
app: Optional[Any] = None
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
self.adapter = EmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
|
||||
self._before_run(query, **kwargs)
|
||||
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
|
||||
def from_embedchain(self, config_path: str):
|
||||
from embedchain import App
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App.from_config(config_path=config_path)
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(name=self.name, description=self.description, adapter=adapter)
|
||||
def _before_run(self, query, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,40 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedTXTSearchToolSchema(BaseModel):
|
||||
"""Input for TXTSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the txt's content")
|
||||
"""Input for TXTSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the txt's content",
|
||||
)
|
||||
|
||||
|
||||
class TXTSearchToolSchema(FixedTXTSearchToolSchema):
|
||||
"""Input for TXTSearchTool."""
|
||||
txt: str = Field(..., description="Mandatory txt path you want to search")
|
||||
"""Input for TXTSearchTool."""
|
||||
|
||||
txt: str = Field(..., description="Mandatory txt path you want to search")
|
||||
|
||||
|
||||
class TXTSearchTool(RagTool):
|
||||
name: str = "Search a txt's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a txt's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = TXTSearchToolSchema
|
||||
txt: Optional[str] = None
|
||||
name: str = "Search a txt's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a txt's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = TXTSearchToolSchema
|
||||
|
||||
def __init__(self, txt: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
self.txt = txt
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, txt: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
txt = kwargs.get('txt', self.txt)
|
||||
self.app = App()
|
||||
self.app.add(txt, data_type=DataType.TEXT_FILE)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.TEXT_FILE
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "txt" in kwargs:
|
||||
self.add(kwargs["txt"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedWebsiteSearchToolSchema(BaseModel):
|
||||
"""Input for WebsiteSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search a specific website")
|
||||
"""Input for WebsiteSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search a specific website",
|
||||
)
|
||||
|
||||
|
||||
class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
|
||||
"""Input for WebsiteSearchTool."""
|
||||
website: str = Field(..., description="Mandatory valid website URL you want to search on")
|
||||
"""Input for WebsiteSearchTool."""
|
||||
|
||||
website: str = Field(
|
||||
..., description="Mandatory valid website URL you want to search on"
|
||||
)
|
||||
|
||||
|
||||
class WebsiteSearchTool(RagTool):
|
||||
name: str = "Search in a specific website"
|
||||
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
||||
website: Optional[str] = None
|
||||
name: str = "Search in a specific website"
|
||||
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
||||
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
||||
|
||||
def __init__(self, website: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if website is not None:
|
||||
self.website = website
|
||||
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
||||
self.args_schema = FixedWebsiteSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, website: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if website is not None:
|
||||
self.add(website)
|
||||
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
||||
self.args_schema = FixedWebsiteSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
website = kwargs.get('website', self.website)
|
||||
self.app = App()
|
||||
self.app.add(website, data_type=DataType.WEB_PAGE)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.WEB_PAGE
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "website" in kwargs:
|
||||
self.add(kwargs["website"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedXMLSearchToolSchema(BaseModel):
|
||||
"""Input for XMLSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the XML's content")
|
||||
"""Input for XMLSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the XML's content",
|
||||
)
|
||||
|
||||
|
||||
class XMLSearchToolSchema(FixedXMLSearchToolSchema):
|
||||
"""Input for XMLSearchTool."""
|
||||
xml: str = Field(..., description="Mandatory xml path you want to search")
|
||||
"""Input for XMLSearchTool."""
|
||||
|
||||
xml: str = Field(..., description="Mandatory xml path you want to search")
|
||||
|
||||
|
||||
class XMLSearchTool(RagTool):
|
||||
name: str = "Search a XML's content"
|
||||
description: str = "A tool that can be used to semantic search a query from a XML's content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = XMLSearchToolSchema
|
||||
xml: Optional[str] = None
|
||||
name: str = "Search a XML's content"
|
||||
description: str = (
|
||||
"A tool that can be used to semantic search a query from a XML's content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = XMLSearchToolSchema
|
||||
|
||||
def __init__(self, xml: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if xml is not None:
|
||||
self.xml = xml
|
||||
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
||||
self.args_schema = FixedXMLSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, xml: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if xml is not None:
|
||||
self.add(xml)
|
||||
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
||||
self.args_schema = FixedXMLSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
xml = kwargs.get('xml', self.xml)
|
||||
self.app = App()
|
||||
self.app.add(xml, data_type=DataType.XML)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.XML
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "xml" in kwargs:
|
||||
self.add(kwargs["xml"])
|
||||
|
||||
@@ -1,43 +1,55 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedYoutubeChannelSearchToolSchema(BaseModel):
|
||||
"""Input for YoutubeChannelSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Channels content")
|
||||
"""Input for YoutubeChannelSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the Youtube Channels content",
|
||||
)
|
||||
|
||||
|
||||
class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
|
||||
"""Input for YoutubeChannelSearchTool."""
|
||||
youtube_channel_handle: str = Field(..., description="Mandatory youtube_channel_handle path you want to search")
|
||||
"""Input for YoutubeChannelSearchTool."""
|
||||
|
||||
youtube_channel_handle: str = Field(
|
||||
..., description="Mandatory youtube_channel_handle path you want to search"
|
||||
)
|
||||
|
||||
|
||||
class YoutubeChannelSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Channels content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
||||
youtube_channel_handle: Optional[str] = None
|
||||
name: str = "Search a Youtube Channels content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_channel_handle is not None:
|
||||
self.youtube_channel_handle = youtube_channel_handle
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
||||
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_channel_handle is not None:
|
||||
self.add(youtube_channel_handle)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
||||
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
youtube_channel_handle = kwargs.get('youtube_channel_handle', self.youtube_channel_handle)
|
||||
if not youtube_channel_handle.startswith("@"):
|
||||
youtube_channel_handle = f"@{youtube_channel_handle}"
|
||||
self.app = App()
|
||||
self.app.add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
youtube_channel_handle: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not youtube_channel_handle.startswith("@"):
|
||||
youtube_channel_handle = f"@{youtube_channel_handle}"
|
||||
|
||||
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
|
||||
super().add(youtube_channel_handle, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_channel_handle" in kwargs:
|
||||
self.add(kwargs["youtube_channel_handle"])
|
||||
|
||||
@@ -1,41 +1,52 @@
|
||||
from typing import Optional, Type, Any
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedYoutubeVideoSearchToolSchema(BaseModel):
|
||||
"""Input for YoutubeVideoSearchTool."""
|
||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Video content")
|
||||
"""Input for YoutubeVideoSearchTool."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Mandatory search query you want to use to search the Youtube Video content",
|
||||
)
|
||||
|
||||
|
||||
class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
|
||||
"""Input for YoutubeVideoSearchTool."""
|
||||
youtube_video_url: str = Field(..., description="Mandatory youtube_video_url path you want to search")
|
||||
"""Input for YoutubeVideoSearchTool."""
|
||||
|
||||
youtube_video_url: str = Field(
|
||||
..., description="Mandatory youtube_video_url path you want to search"
|
||||
)
|
||||
|
||||
|
||||
class YoutubeVideoSearchTool(RagTool):
|
||||
name: str = "Search a Youtube Video content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
summarize: bool = False
|
||||
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
||||
youtube_video_url: Optional[str] = None
|
||||
name: str = "Search a Youtube Video content"
|
||||
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
||||
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
||||
|
||||
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_video_url is not None:
|
||||
self.youtube_video_url = youtube_video_url
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
||||
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
||||
self._generate_description()
|
||||
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if youtube_video_url is not None:
|
||||
self.add(youtube_video_url)
|
||||
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
||||
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
youtube_video_url = kwargs.get('youtube_video_url', self.youtube_video_url)
|
||||
self.app = App()
|
||||
self.app.add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
return super()._run(query=search_query)
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def _before_run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if "youtube_video_url" in kwargs:
|
||||
self.add(kwargs["youtube_video_url"])
|
||||
|
||||
43
tests/tools/rag/rag_tool_test.py
Normal file
43
tests/tools/rag/rag_tool_test.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import cast
|
||||
from unittest import mock
|
||||
|
||||
from pytest import fixture
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def mock_embedchain_db_uri():
|
||||
with NamedTemporaryFile() as tmp:
|
||||
uri = f"sqlite:///{tmp.name}"
|
||||
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
|
||||
yield
|
||||
|
||||
|
||||
def test_custom_llm_and_embedder():
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="openai",
|
||||
config=dict(model="gpt-3.5-custom"),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="openai",
|
||||
config=dict(model="text-embedding-3-custom"),
|
||||
),
|
||||
)
|
||||
)
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, EmbedchainAdapter)
|
||||
|
||||
adapter = cast(EmbedchainAdapter, tool.adapter)
|
||||
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
|
||||
assert (
|
||||
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
|
||||
)
|
||||
Reference in New Issue
Block a user