First take on a rag tool

Gui Vieira
2024-02-13 20:10:56 -03:00
parent 54e4554f49
commit c1182eb322
22 changed files with 454 additions and 0 deletions


@@ -0,0 +1 @@
from .base_tool import BaseTool, Tool, as_tool, tool


@@ -0,0 +1,14 @@
from embedchain import App

from crewai_tools.rag_tool import Adapter


class EmbedchainAdapter(Adapter):
    embedchain_app: App
    dry_run: bool = False

    def query(self, question: str) -> str:
        result = self.embedchain_app.query(question, dry_run=self.dry_run)

        # embedchain may return a list of chunks or a single answer; normalize to str.
        if isinstance(result, list):
            return "\n".join(result)

        return str(result)
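For context, a minimal usage sketch of this adapter plugged into the RagTool defined later in this commit. It assumes embedchain is installed and an OpenAI key is configured; the file name and question are made up.

from embedchain import App
from embedchain.models.data_type import DataType

from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
from crewai_tools.rag_tool import RagTool

# Index a hypothetical local file with embedchain, then expose it as a tool.
app = App()
app.add("notes.txt", data_type=DataType.TEXT_FILE)

tool = RagTool(adapter=EmbedchainAdapter(embedchain_app=app))
print(tool.run("What do the notes say about the launch date?"))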


@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Any, Callable

from lancedb import DBConnection as LanceDBConnection
from lancedb import connect as lancedb_connect
from lancedb.table import Table as LanceDBTable
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr

from crewai_tools.rag_tool import Adapter


def _default_embedding_function():
    client = OpenAIClient()

    def _embedding_function(input):
        rs = client.embeddings.create(input=input, model="text-embedding-ada-002")
        return [record.embedding for record in rs.data]

    return _embedding_function


class LanceDBAdapter(Adapter):
    uri: str | Path
    table_name: str
    embedding_function: Callable = Field(default_factory=_default_embedding_function)
    top_k: int = 3
    vector_column_name: str = "vector"
    text_column_name: str = "text"

    _db: LanceDBConnection = PrivateAttr()
    _table: LanceDBTable = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        self._db = lancedb_connect(self.uri)
        self._table = self._db.open_table(self.table_name)

        return super().model_post_init(__context)

    def query(self, question: str) -> str:
        query = self.embedding_function([question])[0]
        results = (
            self._table.search(query, vector_column_name=self.vector_column_name)
            .limit(self.top_k)
            .select([self.text_column_name])
            .to_list()
        )
        values = [result[self.text_column_name] for result in results]
        return "\n".join(values)


@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, cast

from langchain.agents import tools as langchain_tools
from pydantic import BaseModel


class BaseTool(BaseModel, ABC):
    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Used to tell the model how/when/why to use the tool."""

    def run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        return self._run(*args, **kwargs)

    @abstractmethod
    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Here goes the actual implementation of the tool."""

    def to_langchain(self) -> langchain_tools.Tool:
        return langchain_tools.Tool(
            name=self.name,
            description=self.description,
            func=self._run,
        )


class Tool(BaseTool):
    func: Callable
    """The function that will be executed when the tool is called."""

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


def to_langchain(
    tools: list[BaseTool | langchain_tools.BaseTool],
) -> list[langchain_tools.BaseTool]:
    return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]


def tool(*args):
    """
    Decorator to create a tool from a function.
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(f: Callable) -> BaseTool:
            if f.__doc__ is None:
                raise ValueError("Function must have a docstring")

            return Tool(
                name=tool_name,
                description=f.__doc__,
                func=f,
            )

        return _make_tool

    if len(args) == 1 and callable(args[0]):
        return _make_with_name(args[0].__name__)(args[0])
    if len(args) == 1 and isinstance(args[0], str):
        return _make_with_name(args[0])
    raise ValueError("Invalid arguments")


def as_tool(f: Any) -> BaseTool:
    """
    Useful for when you create a tool using the @tool decorator and want to
    use it as a BaseTool. It is a BaseTool, but type inference doesn't know that.
    """
    assert isinstance(f, BaseTool)
    return cast(BaseTool, f)
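To show how the decorator and helpers fit together, a small sketch follows; the function and its body are made up.

from crewai_tools.base_tool import as_tool, to_langchain, tool

@tool("Word counter")
def count_words(text: str) -> str:
    """Counts the words in the given text."""
    return str(len(text.split()))

# as_tool only narrows the static type; count_words is already a Tool instance.
counter = as_tool(count_words)
print(counter.run("one two three"))  # -> "3"

# Convert to LangChain tools before handing them to a LangChain agent.
langchain_ready = to_langchain([counter])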


@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, ConfigDict

from crewai_tools.base_tool import BaseTool


class Adapter(BaseModel, ABC):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def query(self, question: str) -> str:
        """Query the knowledge base with a question and return the answer."""


class RagTool(BaseTool):
    name: str = "Knowledge base"
    description: str = "A knowledge base that can be used to answer questions."
    adapter: Adapter

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        return self.adapter.query(args[0])

    def from_file(self, file_path: str):
        from embedchain import App
        from embedchain.models.data_type import DataType

        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        app = App()
        app.add(file_path, data_type=DataType.TEXT_FILE)

        adapter = EmbedchainAdapter(embedchain_app=app)
        return RagTool(adapter=adapter)

    def from_directory(self, directory_path: str):
        from embedchain import App
        from embedchain.loaders.directory_loader import DirectoryLoader

        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        loader = DirectoryLoader(config=dict(recursive=True))
        app = App()
        app.add(directory_path, loader=loader)

        adapter = EmbedchainAdapter(embedchain_app=app)
        return RagTool(adapter=adapter)

    def from_web_page(self, url: str):
        from embedchain import App
        from embedchain.models.data_type import DataType

        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        app = App()
        app.add(url, data_type=DataType.WEB_PAGE)

        adapter = EmbedchainAdapter(embedchain_app=app)
        return RagTool(adapter=adapter)

    def from_embedchain(self, config_path: str):
        from embedchain import App

        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        app = App.from_config(config_path=config_path)
        adapter = EmbedchainAdapter(embedchain_app=app)
        return RagTool(adapter=adapter)
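As committed, the from_* helpers are instance methods that never read self and return a fresh RagTool, so they behave like alternate constructors; marking them @classmethod would let callers skip the placeholder argument below. A hedged sketch of the intended flow, assuming embedchain and an OpenAI key are available; the file path is hypothetical.

from crewai_tools.rag_tool import RagTool

# from_file ignores self, so it can be called through the class with a
# throwaway first argument; as a classmethod this would simply be
# RagTool.from_file("handbook.txt").
tool = RagTool.from_file(None, "handbook.txt")
print(tool.run("What is the vacation policy?"))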