mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Custom model config for RAG tools
This commit is contained in:
@@ -1,12 +1,25 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||||
|
|
||||||
|
|
||||||
class EmbedchainAdapter(Adapter):
|
class EmbedchainAdapter(Adapter):
|
||||||
embedchain_app: Any
|
embedchain_app: App
|
||||||
summarize: bool = False
|
summarize: bool = False
|
||||||
|
|
||||||
def query(self, question: str) -> str:
|
def query(self, question: str) -> str:
|
||||||
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize))
|
result, sources = self.embedchain_app.query(
|
||||||
|
question, citations=True, dry_run=(not self.summarize)
|
||||||
|
)
|
||||||
if self.summarize:
|
if self.summarize:
|
||||||
return result
|
return result
|
||||||
return "\n\n".join([source[0] for source in sources])
|
return "\n\n".join([source[0] for source in sources])
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.embedchain_app.add(*args, **kwargs)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
|
|||||||
self._db = lancedb_connect(self.uri)
|
self._db = lancedb_connect(self.uri)
|
||||||
self._table = self._db.open_table(self.table_name)
|
self._table = self._db.open_table(self.table_name)
|
||||||
|
|
||||||
return super().model_post_init(__context)
|
super().model_post_init(__context)
|
||||||
|
|
||||||
def query(self, question: str) -> str:
|
def query(self, question: str) -> str:
|
||||||
query = self.embedding_function([question])[0]
|
query = self.embedding_function([question])[0]
|
||||||
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
|
|||||||
)
|
)
|
||||||
values = [result[self.text_column_name] for result in results]
|
values = [result[self.text_column_name] for result in results]
|
||||||
return "\n".join(values)
|
return "\n".join(values)
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._table.add(*args, **kwargs)
|
||||||
|
|||||||
@@ -1,28 +1,47 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Optional, Type
|
from typing import Any, Callable, Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||||
from pydantic.v1 import BaseModel as V1BaseModel
|
from pydantic.v1 import BaseModel as V1BaseModel
|
||||||
|
|
||||||
from langchain_core.tools import StructuredTool
|
|
||||||
|
|
||||||
class BaseTool(BaseModel, ABC):
|
class BaseTool(BaseModel, ABC):
|
||||||
|
class _ArgsSchemaPlaceholder(V1BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_config = ConfigDict()
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
"""The unique name of the tool that clearly communicates its purpose."""
|
"""The unique name of the tool that clearly communicates its purpose."""
|
||||||
description: str
|
description: str
|
||||||
"""Used to tell the model how/when/why to use the tool."""
|
"""Used to tell the model how/when/why to use the tool."""
|
||||||
args_schema: Optional[Type[V1BaseModel]] = None
|
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||||
"""The schema for the arguments that the tool accepts."""
|
"""The schema for the arguments that the tool accepts."""
|
||||||
description_updated: bool = False
|
description_updated: bool = False
|
||||||
"""Flag to check if the description has been updated."""
|
"""Flag to check if the description has been updated."""
|
||||||
cache_function: Optional[Callable] = lambda: True
|
cache_function: Optional[Callable] = lambda: True
|
||||||
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@validator("args_schema", always=True, pre=True)
|
||||||
def _check_args_schema(self):
|
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
|
||||||
self._set_args_schema()
|
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||||
|
return v
|
||||||
|
|
||||||
|
return type(
|
||||||
|
f"{cls.__name__}Schema",
|
||||||
|
(V1BaseModel,),
|
||||||
|
{
|
||||||
|
"__annotations__": {
|
||||||
|
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
self._generate_description()
|
self._generate_description()
|
||||||
return self
|
|
||||||
|
super().model_post_init(__context)
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
|
|||||||
(V1BaseModel,),
|
(V1BaseModel,),
|
||||||
{
|
{
|
||||||
"__annotations__": {
|
"__annotations__": {
|
||||||
k: v for k, v in self._run.__annotations__.items() if k != 'return'
|
k: v
|
||||||
|
for k, v in self._run.__annotations__.items()
|
||||||
|
if k != "return"
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_description(self):
|
def _generate_description(self):
|
||||||
args = []
|
args = []
|
||||||
for arg, attribute in self.args_schema.schema()['properties'].items():
|
for arg, attribute in self.args_schema.schema()["properties"].items():
|
||||||
args.append(f"{arg}: '{attribute['type']}'")
|
if "type" in attribute:
|
||||||
|
args.append(f"{arg}: '{attribute['type']}'")
|
||||||
|
|
||||||
description = self.description.replace('\n', ' ')
|
description = self.description.replace("\n", " ")
|
||||||
self.description = f"{self.name}({', '.join(args)}) - {description}"
|
self.description = f"{self.name}({', '.join(args)}) - {description}"
|
||||||
|
|
||||||
|
|
||||||
@@ -93,19 +116,19 @@ def tool(*args):
|
|||||||
def _make_tool(f: Callable) -> BaseTool:
|
def _make_tool(f: Callable) -> BaseTool:
|
||||||
if f.__doc__ is None:
|
if f.__doc__ is None:
|
||||||
raise ValueError("Function must have a docstring")
|
raise ValueError("Function must have a docstring")
|
||||||
|
if f.__annotations__ is None:
|
||||||
|
raise ValueError("Function must have type annotations")
|
||||||
|
|
||||||
args_schema = None
|
class_name = "".join(tool_name.split()).title()
|
||||||
if f.__annotations__:
|
args_schema = type(
|
||||||
class_name = "".join(tool_name.split()).title()
|
class_name,
|
||||||
args_schema = type(
|
(V1BaseModel,),
|
||||||
class_name,
|
{
|
||||||
(V1BaseModel,),
|
"__annotations__": {
|
||||||
{
|
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||||
"__annotations__": {
|
|
||||||
k: v for k, v in f.__annotations__.items() if k != 'return'
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return Tool(
|
return Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
@@ -120,4 +143,4 @@ def tool(*args):
|
|||||||
return _make_with_name(args[0].__name__)(args[0])
|
return _make_with_name(args[0].__name__)(args[0])
|
||||||
if len(args) == 1 and isinstance(args[0], str):
|
if len(args) == 1 and isinstance(args[0], str):
|
||||||
return _make_with_name(args[0])
|
return _make_with_name(args[0])
|
||||||
raise ValueError("Invalid arguments")
|
raise ValueError("Invalid arguments")
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedCodeDocsSearchToolSchema(BaseModel):
|
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||||
"""Input for CodeDocsSearchTool."""
|
"""Input for CodeDocsSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Code Docs content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the Code Docs content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
|
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
|
||||||
"""Input for CodeDocsSearchTool."""
|
"""Input for CodeDocsSearchTool."""
|
||||||
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
|
|
||||||
|
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class CodeDocsSearchTool(RagTool):
|
class CodeDocsSearchTool(RagTool):
|
||||||
name: str = "Search a Code Docs content"
|
name: str = "Search a Code Docs content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a Code Docs content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a Code Docs content."
|
||||||
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
)
|
||||||
docs_url: Optional[str] = None
|
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
def __init__(self, docs_url: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if docs_url is not None:
|
if docs_url is not None:
|
||||||
self.docs_url = docs_url
|
self.add(docs_url)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
|
||||||
self.args_schema = FixedCodeDocsSearchToolSchema
|
self.args_schema = FixedCodeDocsSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
docs_url = kwargs.get('docs_url', self.docs_url)
|
kwargs["data_type"] = DataType.DOCS_SITE
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(docs_url, data_type=DataType.DOCS_SITE)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "docs_url" in kwargs:
|
||||||
|
self.add(kwargs["docs_url"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedCSVSearchToolSchema(BaseModel):
|
class FixedCSVSearchToolSchema(BaseModel):
|
||||||
"""Input for CSVSearchTool."""
|
"""Input for CSVSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the CSV's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the CSV's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CSVSearchToolSchema(FixedCSVSearchToolSchema):
|
class CSVSearchToolSchema(FixedCSVSearchToolSchema):
|
||||||
"""Input for CSVSearchTool."""
|
"""Input for CSVSearchTool."""
|
||||||
csv: str = Field(..., description="Mandatory csv path you want to search")
|
|
||||||
|
csv: str = Field(..., description="Mandatory csv path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class CSVSearchTool(RagTool):
|
class CSVSearchTool(RagTool):
|
||||||
name: str = "Search a CSV's content"
|
name: str = "Search a CSV's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a CSV's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a CSV's content."
|
||||||
args_schema: Type[BaseModel] = CSVSearchToolSchema
|
)
|
||||||
csv: Optional[str] = None
|
args_schema: Type[BaseModel] = CSVSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, csv: Optional[str] = None, **kwargs):
|
def __init__(self, csv: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if csv is not None:
|
if csv is not None:
|
||||||
self.csv = csv
|
self.add(csv)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
|
||||||
self.args_schema = FixedCSVSearchToolSchema
|
self.args_schema = FixedCSVSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
csv = kwargs.get('csv', self.csv)
|
kwargs["data_type"] = DataType.CSV
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(csv, data_type=DataType.CSV)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "csv" in kwargs:
|
||||||
|
self.add(kwargs["csv"])
|
||||||
|
|||||||
@@ -1,42 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedDirectorySearchToolSchema(BaseModel):
|
class FixedDirectorySearchToolSchema(BaseModel):
|
||||||
"""Input for DirectorySearchTool."""
|
"""Input for DirectorySearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the directory's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the directory's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DirectorySearchToolSchema(FixedDirectorySearchToolSchema):
|
class DirectorySearchToolSchema(FixedDirectorySearchToolSchema):
|
||||||
"""Input for DirectorySearchTool."""
|
"""Input for DirectorySearchTool."""
|
||||||
directory: str = Field(..., description="Mandatory directory you want to search")
|
|
||||||
|
directory: str = Field(..., description="Mandatory directory you want to search")
|
||||||
|
|
||||||
|
|
||||||
class DirectorySearchTool(RagTool):
|
class DirectorySearchTool(RagTool):
|
||||||
name: str = "Search a directory's content"
|
name: str = "Search a directory's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a directory's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a directory's content."
|
||||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
)
|
||||||
directory: Optional[str] = None
|
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||||
|
|
||||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if directory is not None:
|
if directory is not None:
|
||||||
self.directory = directory
|
self.add(directory)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
|
||||||
self.args_schema = FixedDirectorySearchToolSchema
|
self.args_schema = FixedDirectorySearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
directory = kwargs.get('directory', self.directory)
|
kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
|
||||||
loader = DirectoryLoader(config=dict(recursive=True))
|
super().add(*args, **kwargs)
|
||||||
self.app = App()
|
|
||||||
self.app.add(directory, loader=loader)
|
def _before_run(
|
||||||
return super()._run(query=search_query)
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "directory" in kwargs:
|
||||||
|
self.add(kwargs["directory"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedDOCXSearchToolSchema(BaseModel):
|
class FixedDOCXSearchToolSchema(BaseModel):
|
||||||
"""Input for DOCXSearchTool."""
|
"""Input for DOCXSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the DOCX's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the DOCX's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DOCXSearchToolSchema(FixedDOCXSearchToolSchema):
|
class DOCXSearchToolSchema(FixedDOCXSearchToolSchema):
|
||||||
"""Input for DOCXSearchTool."""
|
"""Input for DOCXSearchTool."""
|
||||||
docx: str = Field(..., description="Mandatory docx path you want to search")
|
|
||||||
|
docx: str = Field(..., description="Mandatory docx path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class DOCXSearchTool(RagTool):
|
class DOCXSearchTool(RagTool):
|
||||||
name: str = "Search a DOCX's content"
|
name: str = "Search a DOCX's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a DOCX's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a DOCX's content."
|
||||||
args_schema: Type[BaseModel] = DOCXSearchToolSchema
|
)
|
||||||
docx: Optional[str] = None
|
args_schema: Type[BaseModel] = DOCXSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, docx: Optional[str] = None, **kwargs):
|
def __init__(self, docx: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if docx is not None:
|
if docx is not None:
|
||||||
self.docx = docx
|
self.add(docx)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
|
||||||
self.args_schema = FixedDOCXSearchToolSchema
|
self.args_schema = FixedDOCXSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
docx = kwargs.get('docx', self.docx)
|
kwargs["data_type"] = DataType.DOCX
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(docx, data_type=DataType.DOCX)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "docx" in kwargs:
|
||||||
|
self.add(kwargs["docx"])
|
||||||
|
|||||||
@@ -1,46 +1,58 @@
|
|||||||
from typing import Optional, Type, List, Any
|
from typing import Any, List, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.loaders.github import GithubLoader
|
from embedchain.loaders.github import GithubLoader
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedGithubSearchToolSchema(BaseModel):
|
class FixedGithubSearchToolSchema(BaseModel):
|
||||||
"""Input for GithubSearchTool."""
|
"""Input for GithubSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the github repo's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the github repo's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
class GithubSearchToolSchema(FixedGithubSearchToolSchema):
|
||||||
"""Input for GithubSearchTool."""
|
"""Input for GithubSearchTool."""
|
||||||
github_repo: str = Field(..., description="Mandatory github you want to search")
|
|
||||||
content_types: List[str] = Field(..., description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]")
|
github_repo: str = Field(..., description="Mandatory github you want to search")
|
||||||
|
content_types: List[str] = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GithubSearchTool(RagTool):
|
class GithubSearchTool(RagTool):
|
||||||
name: str = "Search a github repo's content"
|
name: str = "Search a github repo's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a github repo's content."
|
description: str = "A tool that can be used to semantic search a query from a github repo's content."
|
||||||
summarize: bool = False
|
summarize: bool = False
|
||||||
gh_token: str = None
|
gh_token: str
|
||||||
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
args_schema: Type[BaseModel] = GithubSearchToolSchema
|
||||||
github_repo: Optional[str] = None
|
content_types: List[str]
|
||||||
content_types: List[str]
|
|
||||||
|
|
||||||
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
def __init__(self, github_repo: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if github_repo is not None:
|
if github_repo is not None:
|
||||||
self.github_repo = github_repo
|
self.add(github_repo)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
|
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
|
||||||
self.args_schema = FixedGithubSearchToolSchema
|
self.args_schema = FixedGithubSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
github_repo = kwargs.get('github_repo', self.github_repo)
|
kwargs["data_type"] = "github"
|
||||||
loader = GithubLoader(config={"token": self.gh_token})
|
kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
|
||||||
app = App()
|
super().add(*args, **kwargs)
|
||||||
app.add(f"repo:{github_repo} type:{','.join(self.content_types)}", data_type="github", loader=loader)
|
|
||||||
self.app = app
|
def _before_run(
|
||||||
return super()._run(query=search_query)
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "github_repo" in kwargs:
|
||||||
|
self.add(kwargs["github_repo"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedJSONSearchToolSchema(BaseModel):
|
class FixedJSONSearchToolSchema(BaseModel):
|
||||||
"""Input for JSONSearchTool."""
|
"""Input for JSONSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the JSON's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the JSON's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JSONSearchToolSchema(FixedJSONSearchToolSchema):
|
class JSONSearchToolSchema(FixedJSONSearchToolSchema):
|
||||||
"""Input for JSONSearchTool."""
|
"""Input for JSONSearchTool."""
|
||||||
json_path: str = Field(..., description="Mandatory json path you want to search")
|
|
||||||
|
json_path: str = Field(..., description="Mandatory json path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class JSONSearchTool(RagTool):
|
class JSONSearchTool(RagTool):
|
||||||
name: str = "Search a JSON's content"
|
name: str = "Search a JSON's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a JSON's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a JSON's content."
|
||||||
args_schema: Type[BaseModel] = JSONSearchToolSchema
|
)
|
||||||
json_path: Optional[str] = None
|
args_schema: Type[BaseModel] = JSONSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
def __init__(self, json_path: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if json_path is not None:
|
if json_path is not None:
|
||||||
self.json_path = json_path
|
self.add(json_path)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {json} JSON's content."
|
self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
|
||||||
self.args_schema = FixedJSONSearchToolSchema
|
self.args_schema = FixedJSONSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
json_path = kwargs.get('json_path', self.json_path)
|
kwargs["data_type"] = DataType.JSON
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(json_path, data_type=DataType.JSON)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "json_path" in kwargs:
|
||||||
|
self.add(kwargs["json_path"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedMDXSearchToolSchema(BaseModel):
|
class FixedMDXSearchToolSchema(BaseModel):
|
||||||
"""Input for MDXSearchTool."""
|
"""Input for MDXSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the MDX's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the MDX's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MDXSearchToolSchema(FixedMDXSearchToolSchema):
|
class MDXSearchToolSchema(FixedMDXSearchToolSchema):
|
||||||
"""Input for MDXSearchTool."""
|
"""Input for MDXSearchTool."""
|
||||||
mdx: str = Field(..., description="Mandatory mdx path you want to search")
|
|
||||||
|
mdx: str = Field(..., description="Mandatory mdx path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class MDXSearchTool(RagTool):
|
class MDXSearchTool(RagTool):
|
||||||
name: str = "Search a MDX's content"
|
name: str = "Search a MDX's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a MDX's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a MDX's content."
|
||||||
args_schema: Type[BaseModel] = MDXSearchToolSchema
|
)
|
||||||
mdx: Optional[str] = None
|
args_schema: Type[BaseModel] = MDXSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
def __init__(self, mdx: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if mdx is not None:
|
if mdx is not None:
|
||||||
self.mdx = mdx
|
self.add(mdx)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
|
||||||
self.args_schema = FixedMDXSearchToolSchema
|
self.args_schema = FixedMDXSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
mdx = kwargs.get('mdx', self.mdx)
|
kwargs["data_type"] = DataType.MDX
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(mdx, data_type=DataType.MDX)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "mdx" in kwargs:
|
||||||
|
self.add(kwargs["mdx"])
|
||||||
|
|||||||
@@ -1,41 +1,51 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedPDFSearchToolSchema(BaseModel):
|
class FixedPDFSearchToolSchema(BaseModel):
|
||||||
"""Input for PDFSearchTool."""
|
"""Input for PDFSearchTool."""
|
||||||
query: str = Field(..., description="Mandatory query you want to use to search the PDF's content")
|
|
||||||
|
query: str = Field(
|
||||||
|
..., description="Mandatory query you want to use to search the PDF's content"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PDFSearchToolSchema(FixedPDFSearchToolSchema):
|
class PDFSearchToolSchema(FixedPDFSearchToolSchema):
|
||||||
"""Input for PDFSearchTool."""
|
"""Input for PDFSearchTool."""
|
||||||
pdf: str = Field(..., description="Mandatory pdf path you want to search")
|
|
||||||
|
pdf: str = Field(..., description="Mandatory pdf path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class PDFSearchTool(RagTool):
|
class PDFSearchTool(RagTool):
|
||||||
name: str = "Search a PDF's content"
|
name: str = "Search a PDF's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a PDF's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a PDF's content."
|
||||||
args_schema: Type[BaseModel] = PDFSearchToolSchema
|
)
|
||||||
pdf: Optional[str] = None
|
args_schema: Type[BaseModel] = PDFSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
def __init__(self, pdf: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if pdf is not None:
|
if pdf is not None:
|
||||||
self.pdf = pdf
|
self.add(pdf)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||||
self.args_schema = FixedPDFSearchToolSchema
|
self.args_schema = FixedPDFSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
pdf = kwargs.get('pdf', self.pdf)
|
kwargs["data_type"] = DataType.PDF_FILE
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(pdf, data_type=DataType.PDF_FILE)
|
|
||||||
return super()._run(query=query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "pdf" in kwargs:
|
||||||
|
self.add(kwargs["pdf"])
|
||||||
|
|||||||
@@ -1,45 +1,37 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.loaders.postgres import PostgresLoader
|
from embedchain.loaders.postgres import PostgresLoader
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class PGSearchToolSchema(BaseModel):
|
class PGSearchToolSchema(BaseModel):
|
||||||
"""Input for PGSearchTool."""
|
"""Input for PGSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory semantic search query you want to use to search the database's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PGSearchTool(RagTool):
|
class PGSearchTool(RagTool):
|
||||||
name: str = "Search a database's table content"
|
name: str = "Search a database's table content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a database table's content."
|
description: str = "A tool that can be used to semantic search a query from a database table's content."
|
||||||
summarize: bool = False
|
args_schema: Type[BaseModel] = PGSearchToolSchema
|
||||||
args_schema: Type[BaseModel] = PGSearchToolSchema
|
db_uri: str = Field(..., description="Mandatory database URI")
|
||||||
db_uri: str = Field(..., description="Mandatory database URI")
|
|
||||||
table_name: str = Field(..., description="Mandatory table name")
|
|
||||||
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
|
|
||||||
|
|
||||||
def __init__(self, table_name: Optional[str] = None, **kwargs):
|
def __init__(self, table_name: str, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if table_name is not None:
|
self.add(table_name)
|
||||||
self.table_name = table_name
|
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
self._generate_description()
|
||||||
self._generate_description()
|
|
||||||
else:
|
|
||||||
raise('To use PGSearchTool, you must provide a `table_name` argument')
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
table_name: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
|
kwargs["data_type"] = "postgres"
|
||||||
config = { "url": self.db_uri }
|
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
|
||||||
postgres_loader = PostgresLoader(config=config)
|
super().add(f"SELECT * FROM {table_name};", **kwargs)
|
||||||
app = App()
|
|
||||||
app.add(
|
|
||||||
f"SELECT * FROM {self.table_name};",
|
|
||||||
data_type='postgres',
|
|
||||||
loader=postgres_loader
|
|
||||||
)
|
|
||||||
return super()._run(query=search_query)
|
|
||||||
|
|||||||
@@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory')
|
|||||||
|
|
||||||
# Example: Loading from a web page
|
# Example: Loading from a web page
|
||||||
rag_tool = RagTool().from_web_page('https://example.com')
|
rag_tool = RagTool().from_web_page('https://example.com')
|
||||||
|
|
||||||
# Example: Loading from an Embedchain configuration
|
|
||||||
rag_tool = RagTool().from_embedchain('path/to/your/config.json')
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## **Contribution**
|
## **Contribution**
|
||||||
@@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To
|
|||||||
|
|
||||||
RagTool is open-source and available under the MIT license.
|
RagTool is open-source and available under the MIT license.
|
||||||
|
|
||||||
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
|
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
|
||||||
|
|||||||
@@ -1,38 +1,71 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic.v1 import BaseModel, ConfigDict
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from crewai_tools.tools.base_tool import BaseTool
|
from crewai_tools.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class Adapter(BaseModel, ABC):
|
class Adapter(BaseModel, ABC):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def query(self, question: str) -> str:
|
def query(self, question: str) -> str:
|
||||||
"""Query the knowledge base with a question and return the answer."""
|
"""Query the knowledge base with a question and return the answer."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Add content to the knowledge base."""
|
||||||
|
|
||||||
|
|
||||||
class RagTool(BaseTool):
|
class RagTool(BaseTool):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
class _AdapterPlaceholder(Adapter):
|
||||||
|
def query(self, question: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
name: str = "Knowledge base"
|
name: str = "Knowledge base"
|
||||||
description: str = "A knowledge base that can be used to answer questions."
|
description: str = "A knowledge base that can be used to answer questions."
|
||||||
summarize: bool = False
|
summarize: bool = False
|
||||||
adapter: Optional[Adapter] = None
|
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||||
app: Optional[Any] = None
|
config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _set_default_adapter(self):
|
||||||
|
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||||
|
|
||||||
|
app = App.from_config(config=self.config) if self.config else App()
|
||||||
|
self.adapter = EmbedchainAdapter(
|
||||||
|
embedchain_app=app, summarize=self.summarize
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.adapter.add(*args, **kwargs)
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
self._before_run(query, **kwargs)
|
||||||
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
|
|
||||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||||
|
|
||||||
def from_embedchain(self, config_path: str):
|
def _before_run(self, query, **kwargs):
|
||||||
from embedchain import App
|
pass
|
||||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
|
||||||
|
|
||||||
app = App.from_config(config_path=config_path)
|
|
||||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
|
||||||
return RagTool(name=self.name, description=self.description, adapter=adapter)
|
|
||||||
|
|||||||
@@ -1,40 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedTXTSearchToolSchema(BaseModel):
|
class FixedTXTSearchToolSchema(BaseModel):
|
||||||
"""Input for TXTSearchTool."""
|
"""Input for TXTSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the txt's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the txt's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TXTSearchToolSchema(FixedTXTSearchToolSchema):
|
class TXTSearchToolSchema(FixedTXTSearchToolSchema):
|
||||||
"""Input for TXTSearchTool."""
|
"""Input for TXTSearchTool."""
|
||||||
txt: str = Field(..., description="Mandatory txt path you want to search")
|
|
||||||
|
txt: str = Field(..., description="Mandatory txt path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class TXTSearchTool(RagTool):
|
class TXTSearchTool(RagTool):
|
||||||
name: str = "Search a txt's content"
|
name: str = "Search a txt's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a txt's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a txt's content."
|
||||||
args_schema: Type[BaseModel] = TXTSearchToolSchema
|
)
|
||||||
txt: Optional[str] = None
|
args_schema: Type[BaseModel] = TXTSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, txt: Optional[str] = None, **kwargs):
|
def __init__(self, txt: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if txt is not None:
|
if txt is not None:
|
||||||
self.txt = txt
|
self.add(txt)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||||
self.args_schema = FixedTXTSearchToolSchema
|
self.args_schema = FixedTXTSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
txt = kwargs.get('txt', self.txt)
|
kwargs["data_type"] = DataType.TEXT_FILE
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(txt, data_type=DataType.TEXT_FILE)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "txt" in kwargs:
|
||||||
|
self.add(kwargs["txt"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedWebsiteSearchToolSchema(BaseModel):
|
class FixedWebsiteSearchToolSchema(BaseModel):
|
||||||
"""Input for WebsiteSearchTool."""
|
"""Input for WebsiteSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search a specific website")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search a specific website",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
|
class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
|
||||||
"""Input for WebsiteSearchTool."""
|
"""Input for WebsiteSearchTool."""
|
||||||
website: str = Field(..., description="Mandatory valid website URL you want to search on")
|
|
||||||
|
website: str = Field(
|
||||||
|
..., description="Mandatory valid website URL you want to search on"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebsiteSearchTool(RagTool):
|
class WebsiteSearchTool(RagTool):
|
||||||
name: str = "Search in a specific website"
|
name: str = "Search in a specific website"
|
||||||
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
description: str = "A tool that can be used to semantic search a query from a specific URL content."
|
||||||
summarize: bool = False
|
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
||||||
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
|
|
||||||
website: Optional[str] = None
|
|
||||||
|
|
||||||
def __init__(self, website: Optional[str] = None, **kwargs):
|
def __init__(self, website: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if website is not None:
|
if website is not None:
|
||||||
self.website = website
|
self.add(website)
|
||||||
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
self.description = f"A tool that can be used to semantic search a query from {website} website content."
|
||||||
self.args_schema = FixedWebsiteSearchToolSchema
|
self.args_schema = FixedWebsiteSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
website = kwargs.get('website', self.website)
|
kwargs["data_type"] = DataType.WEB_PAGE
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(website, data_type=DataType.WEB_PAGE)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "website" in kwargs:
|
||||||
|
self.add(kwargs["website"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedXMLSearchToolSchema(BaseModel):
|
class FixedXMLSearchToolSchema(BaseModel):
|
||||||
"""Input for XMLSearchTool."""
|
"""Input for XMLSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the XML's content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the XML's content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class XMLSearchToolSchema(FixedXMLSearchToolSchema):
|
class XMLSearchToolSchema(FixedXMLSearchToolSchema):
|
||||||
"""Input for XMLSearchTool."""
|
"""Input for XMLSearchTool."""
|
||||||
xml: str = Field(..., description="Mandatory xml path you want to search")
|
|
||||||
|
xml: str = Field(..., description="Mandatory xml path you want to search")
|
||||||
|
|
||||||
|
|
||||||
class XMLSearchTool(RagTool):
|
class XMLSearchTool(RagTool):
|
||||||
name: str = "Search a XML's content"
|
name: str = "Search a XML's content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a XML's content."
|
description: str = (
|
||||||
summarize: bool = False
|
"A tool that can be used to semantic search a query from a XML's content."
|
||||||
args_schema: Type[BaseModel] = XMLSearchToolSchema
|
)
|
||||||
xml: Optional[str] = None
|
args_schema: Type[BaseModel] = XMLSearchToolSchema
|
||||||
|
|
||||||
def __init__(self, xml: Optional[str] = None, **kwargs):
|
def __init__(self, xml: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if xml is not None:
|
if xml is not None:
|
||||||
self.xml = xml
|
self.add(xml)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
|
||||||
self.args_schema = FixedXMLSearchToolSchema
|
self.args_schema = FixedXMLSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
xml = kwargs.get('xml', self.xml)
|
kwargs["data_type"] = DataType.XML
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(xml, data_type=DataType.XML)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "xml" in kwargs:
|
||||||
|
self.add(kwargs["xml"])
|
||||||
|
|||||||
@@ -1,43 +1,55 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedYoutubeChannelSearchToolSchema(BaseModel):
|
class FixedYoutubeChannelSearchToolSchema(BaseModel):
|
||||||
"""Input for YoutubeChannelSearchTool."""
|
"""Input for YoutubeChannelSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Channels content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the Youtube Channels content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
|
class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
|
||||||
"""Input for YoutubeChannelSearchTool."""
|
"""Input for YoutubeChannelSearchTool."""
|
||||||
youtube_channel_handle: str = Field(..., description="Mandatory youtube_channel_handle path you want to search")
|
|
||||||
|
youtube_channel_handle: str = Field(
|
||||||
|
..., description="Mandatory youtube_channel_handle path you want to search"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class YoutubeChannelSearchTool(RagTool):
|
class YoutubeChannelSearchTool(RagTool):
|
||||||
name: str = "Search a Youtube Channels content"
|
name: str = "Search a Youtube Channels content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
|
||||||
summarize: bool = False
|
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
||||||
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
|
|
||||||
youtube_channel_handle: Optional[str] = None
|
|
||||||
|
|
||||||
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if youtube_channel_handle is not None:
|
if youtube_channel_handle is not None:
|
||||||
self.youtube_channel_handle = youtube_channel_handle
|
self.add(youtube_channel_handle)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
|
||||||
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
self.args_schema = FixedYoutubeChannelSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
youtube_channel_handle: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
youtube_channel_handle = kwargs.get('youtube_channel_handle', self.youtube_channel_handle)
|
if not youtube_channel_handle.startswith("@"):
|
||||||
if not youtube_channel_handle.startswith("@"):
|
youtube_channel_handle = f"@{youtube_channel_handle}"
|
||||||
youtube_channel_handle = f"@{youtube_channel_handle}"
|
|
||||||
self.app = App()
|
kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
|
||||||
self.app.add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL)
|
super().add(youtube_channel_handle, **kwargs)
|
||||||
return super()._run(query=search_query)
|
|
||||||
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "youtube_channel_handle" in kwargs:
|
||||||
|
self.add(kwargs["youtube_channel_handle"])
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
from typing import Optional, Type, Any
|
from typing import Any, Optional, Type
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
from ..rag.rag_tool import RagTool
|
from ..rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
class FixedYoutubeVideoSearchToolSchema(BaseModel):
|
class FixedYoutubeVideoSearchToolSchema(BaseModel):
|
||||||
"""Input for YoutubeVideoSearchTool."""
|
"""Input for YoutubeVideoSearchTool."""
|
||||||
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Video content")
|
|
||||||
|
search_query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Mandatory search query you want to use to search the Youtube Video content",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
|
class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
|
||||||
"""Input for YoutubeVideoSearchTool."""
|
"""Input for YoutubeVideoSearchTool."""
|
||||||
youtube_video_url: str = Field(..., description="Mandatory youtube_video_url path you want to search")
|
|
||||||
|
youtube_video_url: str = Field(
|
||||||
|
..., description="Mandatory youtube_video_url path you want to search"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class YoutubeVideoSearchTool(RagTool):
|
class YoutubeVideoSearchTool(RagTool):
|
||||||
name: str = "Search a Youtube Video content"
|
name: str = "Search a Youtube Video content"
|
||||||
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
|
||||||
summarize: bool = False
|
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
||||||
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
|
|
||||||
youtube_video_url: Optional[str] = None
|
|
||||||
|
|
||||||
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if youtube_video_url is not None:
|
if youtube_video_url is not None:
|
||||||
self.youtube_video_url = youtube_video_url
|
self.add(youtube_video_url)
|
||||||
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
|
||||||
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
self.args_schema = FixedYoutubeVideoSearchToolSchema
|
||||||
self._generate_description()
|
|
||||||
|
|
||||||
def _run(
|
def add(
|
||||||
self,
|
self,
|
||||||
search_query: str,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> None:
|
||||||
youtube_video_url = kwargs.get('youtube_video_url', self.youtube_video_url)
|
kwargs["data_type"] = DataType.YOUTUBE_VIDEO
|
||||||
self.app = App()
|
super().add(*args, **kwargs)
|
||||||
self.app.add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
|
||||||
return super()._run(query=search_query)
|
def _before_run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if "youtube_video_url" in kwargs:
|
||||||
|
self.add(kwargs["youtube_video_url"])
|
||||||
|
|||||||
43
tests/tools/rag/rag_tool_test.py
Normal file
43
tests/tools/rag/rag_tool_test.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import cast
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from pytest import fixture
|
||||||
|
|
||||||
|
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||||
|
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(autouse=True)
|
||||||
|
def mock_embedchain_db_uri():
|
||||||
|
with NamedTemporaryFile() as tmp:
|
||||||
|
uri = f"sqlite:///{tmp.name}"
|
||||||
|
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_llm_and_embedder():
|
||||||
|
class MyTool(RagTool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = MyTool(
|
||||||
|
config=dict(
|
||||||
|
llm=dict(
|
||||||
|
provider="openai",
|
||||||
|
config=dict(model="gpt-3.5-custom"),
|
||||||
|
),
|
||||||
|
embedder=dict(
|
||||||
|
provider="openai",
|
||||||
|
config=dict(model="text-embedding-3-custom"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert tool.adapter is not None
|
||||||
|
assert isinstance(tool.adapter, EmbedchainAdapter)
|
||||||
|
|
||||||
|
adapter = cast(EmbedchainAdapter, tool.adapter)
|
||||||
|
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
|
||||||
|
assert (
|
||||||
|
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user