First take on a rag tool

This commit is contained in:
Gui Vieira
2024-02-13 20:10:56 -03:00
parent 54e4554f49
commit c1182eb322
22 changed files with 454 additions and 0 deletions

101
README.md
View File

@@ -0,0 +1,101 @@
## Getting started
When setting up agents you can provide tools for them to use. Here you will find ready-to-use tools as well as simple helpers for you to create your own tools.
In order to create a new tool, you have to pick one of the available strategies.
### Subclassing `BaseTool`
```python
class MyTool(BaseTool):
    """Example tool that searches a vector table of project requirements."""

    name: str = "Knowledge base"
    description: str = "A knowledge base with all the requirements for the project."

    def _run(self, question) -> str:
        """Return the top three search results for ``question``, one per line."""
        passages = (
            tbl.search(embed_func([question])[0]).limit(3).to_pandas()["text"].tolist()
        )
        # ``tolist()`` produces a list, so join it to honour the declared
        # ``str`` return type (assumes the "text" column holds strings).
        return "\n".join(passages)
```
As you can see, all you need to do is to create a new class that inherits from `BaseTool`, define `name` and `description` fields, as well as implement the `_run` method.
### Create tool from a function or lambda
```python
my_tool = Tool(
    name="Knowledge base",
    description="A knowledge base with all the requirements for the project.",
    # Join the matching rows so the tool hands back a single string rather
    # than a Python list.
    func=lambda question: "\n".join(
        tbl.search(embed_func([question])[0]).limit(3).to_pandas()["text"].tolist()
    ),
)
```
Here it's a bit simpler, as you don't have to subclass. Simply create a `Tool` object with the three required fields and you are good to go.
### Use the `tool` decorator.
```python
@tool("Knowledge base")
def my_tool(question: str) -> str:
    """A knowledge base with all the requirements for the project."""
    matches = tbl.search(embed_func([question])[0]).limit(3).to_pandas()["text"]
    # ``tolist()`` produces a list, so join it to honour the declared
    # ``str`` return type.
    return "\n".join(matches.tolist())
```
By using the decorator you can easily wrap simple functions as tools. If you don't provide a name, the function name is going to be used. However, the docstring is required.
If you are using a linter you may see issues when passing your decorated tool in `tools` parameters that expect a list of `BaseTool`. If that's the case, you can use the `as_tool` helper.
## Contribution
This repo is open-source and we welcome contributions. If you're looking to contribute, please:
- Fork the repository.
- Create a new branch for your feature.
- Add your feature or improvement.
- Send a pull request.
- We appreciate your input!
### Installing Dependencies
```bash
poetry install
```
### Virtual Env
```bash
poetry shell
```
### Pre-commit hooks
```bash
pre-commit install
```
### Running Tests
```bash
poetry run pytest
```
### Running static type checks
```bash
poetry run pyright
```
### Packaging
```bash
poetry build
```
### Installing Locally
```bash
pip install dist/*.tar.gz
```

View File

@@ -0,0 +1 @@
from .base_tool import BaseTool, Tool, as_tool, tool

View File

@@ -0,0 +1,14 @@
from embedchain import App
from crewai_tools.rag_tool import Adapter
class EmbedchainAdapter(Adapter):
    """Adapter that answers questions by delegating to an Embedchain ``App``."""

    # The Embedchain application that is queried.
    embedchain_app: App
    # Forwarded unchanged to ``App.query`` as its ``dry_run`` argument.
    dry_run: bool = False

    def query(self, question: str) -> str:
        """Ask the Embedchain app ``question`` and return its reply as text.

        A list reply is flattened to one item per line; any other reply is
        converted with ``str``.
        """
        result = self.embedchain_app.query(question, dry_run=self.dry_run)
        # The previous ``result is list`` compared the value with the ``list``
        # type object itself, which is never true for an actual list, so the
        # join branch was unreachable. ``isinstance`` is the intended check.
        if isinstance(result, list):
            return "\n".join(str(item) for item in result)
        return str(result)

View File

@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Any, Callable
from lancedb import DBConnection as LanceDBConnection
from lancedb import connect as lancedb_connect
from lancedb.table import Table as LanceDBTable
from openai import Client as OpenAIClient
from pydantic import Field, PrivateAttr
from crewai_tools.rag_tool import Adapter
def _default_embedding_function():
    """Build an embedding callable backed by ``text-embedding-ada-002``.

    The OpenAI client is created once here and captured by the returned
    closure, which maps its input to a list of embedding vectors.
    """
    openai_client = OpenAIClient()

    def _embedding_function(input):
        response = openai_client.embeddings.create(
            input=input, model="text-embedding-ada-002"
        )
        embeddings = []
        for record in response.data:
            embeddings.append(record.embedding)
        return embeddings

    return _embedding_function
class LanceDBAdapter(Adapter):
    """Adapter that retrieves text from a LanceDB table through vector search."""

    uri: str | Path
    table_name: str
    embedding_function: Callable = Field(default_factory=_default_embedding_function)
    top_k: int = 3
    vector_column_name: str = "vector"
    text_column_name: str = "text"

    _db: LanceDBConnection = PrivateAttr()
    _table: LanceDBTable = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Connect to ``uri`` and open ``table_name`` after field validation."""
        connection = lancedb_connect(self.uri)
        self._db = connection
        self._table = connection.open_table(self.table_name)
        return super().model_post_init(__context)

    def query(self, question: str) -> str:
        """Embed ``question``, search the table and return the matching text.

        The ``text_column_name`` values of at most ``top_k`` rows are joined
        with newlines.
        """
        # The embedding function takes a batch; only one question is sent,
        # so only the first vector is used.
        question_vector = self.embedding_function([question])[0]
        search = self._table.search(
            question_vector, vector_column_name=self.vector_column_name
        )
        rows = search.limit(self.top_k).select([self.text_column_name]).to_list()
        return "\n".join(row[self.text_column_name] for row in rows)

View File

@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, cast
from langchain.agents import tools as langchain_tools
from pydantic import BaseModel
class BaseTool(BaseModel, ABC):
    """Abstract base for tools; subclasses provide the behaviour in ``_run``."""

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Used to tell the model how/when/why to use the tool."""

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool, forwarding every argument to ``_run``."""
        outcome = self._run(*args, **kwargs)
        return outcome

    @abstractmethod
    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Here goes the actual implementation of the tool."""

    def to_langchain(self) -> langchain_tools.Tool:
        """Build a LangChain ``Tool`` with this tool's name, description and ``_run``."""
        langchain_fields = {
            "name": self.name,
            "description": self.description,
            "func": self._run,
        }
        return langchain_tools.Tool(**langchain_fields)
class Tool(BaseTool):
    """Tool whose behaviour is supplied by an arbitrary callable."""

    func: Callable
    """The function that will be executed when the tool is called."""

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result."""
        target = self.func
        return target(*args, **kwargs)
def to_langchain(
    tools: list[BaseTool | langchain_tools.BaseTool],
) -> list[langchain_tools.BaseTool]:
    """Convert a mixed list of tools into LangChain tools.

    ``BaseTool`` instances are converted through their own ``to_langchain``
    method; every other entry is passed through unchanged. Order is kept.
    """
    converted = []
    for candidate in tools:
        if isinstance(candidate, BaseTool):
            converted.append(candidate.to_langchain())
        else:
            converted.append(candidate)
    return converted
def tool(*args):
    """
    Decorator to create a tool from a function.

    Usable bare (``@tool``), in which case the function's ``__name__`` becomes
    the tool name, or with an explicit name (``@tool("My tool")``). The
    function's docstring becomes the tool description and is mandatory.

    Raises ``ValueError`` when called with anything other than exactly one
    callable or one string, or when the decorated function has no docstring.
    """

    def _decorator_for(tool_name: str) -> Callable:
        def _decorate(func: Callable) -> BaseTool:
            doc = func.__doc__
            if doc is None:
                raise ValueError("Function must have a docstring")
            return Tool(name=tool_name, description=doc, func=func)

        return _decorate

    if len(args) != 1:
        raise ValueError("Invalid arguments")

    (target,) = args
    # Bare ``@tool``: the decorated function itself is the only argument.
    if callable(target):
        return _decorator_for(target.__name__)(target)
    # ``@tool("name")``: hand back a decorator bound to that name.
    if isinstance(target, str):
        return _decorator_for(target)
    raise ValueError("Invalid arguments")
def as_tool(f: Any) -> BaseTool:
    """
    Narrow the static type of a ``@tool``-decorated object to ``BaseTool``.

    The decorator already returns a ``BaseTool`` at runtime, but type checkers
    cannot infer that; this helper asserts it and returns the same object.
    """
    assert isinstance(f, BaseTool)
    narrowed = cast(BaseTool, f)
    return narrowed

View File

@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict
from crewai_tools.base_tool import BaseTool
class Adapter(BaseModel, ABC):
    """Abstract interface between ``RagTool`` and a concrete knowledge base."""

    # Enables pydantic's ``arbitrary_types_allowed`` option for this model and
    # its subclasses.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def query(self, question: str) -> str:
        """Query the knowledge base with a question and return the answer."""
class RagTool(BaseTool):
    """Tool that answers questions by querying an ``Adapter``.

    The ``from_*`` alternate constructors build a tool backed by Embedchain.
    They are classmethods so that a tool can be created without already
    having one (``RagTool.from_file(path)``); calling them on an existing
    instance keeps working.
    """

    name: str = "Knowledge base"
    description: str = "A knowledge base that can be used to answer questions."
    adapter: Adapter

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Send the first positional argument to the adapter as the question."""
        return self.adapter.query(args[0])

    @classmethod
    def _from_embedchain_app(cls, app: Any) -> "RagTool":
        """Wrap an Embedchain ``App`` in an adapter and build a tool around it."""
        # Imported lazily, like ``embedchain`` in the public factories, so the
        # Embedchain dependency is only needed when these factories are used.
        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        return cls(adapter=EmbedchainAdapter(embedchain_app=app))

    @classmethod
    def from_file(cls, file_path: str) -> "RagTool":
        """Create a tool backed by an Embedchain app loaded with one text file."""
        from embedchain import App
        from embedchain.models.data_type import DataType

        app = App()
        app.add(file_path, data_type=DataType.TEXT_FILE)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_directory(cls, directory_path: str) -> "RagTool":
        """Create a tool backed by an Embedchain app loaded from a directory.

        The directory loader is configured with ``recursive=True``.
        """
        from embedchain import App
        from embedchain.loaders.directory_loader import DirectoryLoader

        loader = DirectoryLoader(config={"recursive": True})
        app = App()
        app.add(directory_path, loader=loader)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_web_page(cls, url: str) -> "RagTool":
        """Create a tool backed by an Embedchain app loaded with one web page."""
        from embedchain import App
        from embedchain.models.data_type import DataType

        app = App()
        app.add(url, data_type=DataType.WEB_PAGE)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_embedchain(cls, config_path: str) -> "RagTool":
        """Create a tool from an Embedchain app described by a config file."""
        from embedchain import App

        app = App.from_config(config_path=config_path)
        return cls._from_embedchain_app(app)

View File

@@ -0,0 +1,67 @@
from typing import Callable
from chromadb import Documents, EmbeddingFunction, Embeddings
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.chroma import ChromaDB
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
class MockEmbeddingFunction(EmbeddingFunction):
    """Chroma ``EmbeddingFunction`` that delegates to a plain callable."""

    # The wrapped callable; it receives the documents and returns the embeddings.
    fn: Callable

    def __init__(self, embedding_fn: Callable):
        """Store ``embedding_fn`` as the callable used by ``__call__``."""
        self.fn = embedding_fn

    def __call__(self, input: Documents) -> Embeddings:
        """Return whatever the wrapped callable produces for ``input``."""
        return self.fn(input)
def test_embedchain_adapter(helpers):
    """Check the text ``EmbedchainAdapter.query`` returns with ``dry_run=True``.

    The app is wired to the ChromaDB data under ``tests/data/chromadb`` and to
    the fake embedding function provided by the ``helpers`` fixture.
    """
    # Wrap the fake embedding function in the interface the embedder expects.
    embedding_function = MockEmbeddingFunction(
        embedding_fn=helpers.get_embedding_function()
    )
    embedder = BaseEmbedder()
    embedder.set_embedding_fn(embedding_function) # type: ignore
    db = ChromaDB(
        config=ChromaDbConfig(
            dir="tests/data/chromadb",
            collection_name="requirements",
        )
    )
    app = App(
        config=AppConfig(
            id="test",
        ),
        db=db,
        embedding_model=embedder,
    )
    # NOTE(review): presumably dry_run makes Embedchain return the assembled
    # prompt instead of an LLM answer — confirm against the Embedchain docs.
    adapter = EmbedchainAdapter(
        dry_run=True,
        embedchain_app=app,
    )
    # The comparison is exact, so whitespace inside the literal is significant.
    assert (
        adapter.query("What are the requirements for the task?")
        == """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Technical requirements
The system should be able to process 1000 transactions per second. The code must be written in Ruby. | Problem
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem. | Solution
We need a function that takes a string as input and returns true if the string is a palindrome, otherwise false.
Query: What are the requirements for the task?
Helpful Answer:
"""
    )

View File

@@ -0,0 +1,22 @@
from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter
def test_lancedb_adapter(helpers):
    """Check that ``LanceDBAdapter.query`` returns the two best matches as text.

    Uses the LanceDB data under ``tests/data/lancedb`` and the fake embedding
    function provided by the ``helpers`` fixture.
    """
    adapter = LanceDBAdapter(
        uri="tests/data/lancedb",
        table_name="requirements",
        embedding_function=helpers.get_embedding_function(),
        top_k=2,
        vector_column_name="vector",
        text_column_name="text",
    )
    # The comparison is exact, so whitespace inside the literal is significant.
    assert (
        adapter.query("What are the requirements for the task?")
        == """Technical requirements
The system should be able to process 1000 transactions per second. The code must be written in Ruby.
Problem
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem."""
    )

21
tests/conftest.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Callable
import pytest
class Helpers:
    """Shared test helpers, handed to tests through the ``helpers`` fixture."""

    @staticmethod
    def get_embedding_function() -> Callable:
        """Return a fake embedding function for the canned test question.

        The returned callable accepts only
        ``["What are the requirements for the task?"]`` and answers with a
        single vector parsed from the comma-separated floats stored in
        ``tests/data/embedding.txt``.
        """

        def _func(input):
            # Guard against tests embedding anything but the canned question.
            assert input == ["What are the requirements for the task?"]
            with open("tests/data/embedding.txt", "r") as file:
                raw_values = file.read().split(",")
            vector = []
            for raw in raw_values:
                vector.append(float(raw))
            return [vector]

        return _func
@pytest.fixture
def helpers():
    """Expose the ``Helpers`` class itself (not an instance) to tests."""
    return Helpers

Binary file not shown.

1
tests/data/embedding.txt Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1 @@
$d2c46569-d173-4b3f-b589-f8f00eddc371<37>Vtext <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>*string085vector <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>*fixed_size_list:float:153608

21
tests/rag_tool_test.py Normal file
View File

@@ -0,0 +1,21 @@
from crewai_tools.rag_tool import Adapter, RagTool
class MockAdapter(Adapter):
    """Adapter stub that returns a fixed answer for every question."""

    # The canned reply returned by ``query``.
    answer: str

    def query(self, question: str) -> str:
        """Ignore ``question`` and return the configured ``answer``."""
        return self.answer
def test_rag_tool():
    """RagTool keeps its default metadata and returns the adapter's answer."""
    tool_under_test = RagTool(adapter=MockAdapter(answer="42"))

    assert tool_under_test.name == "Knowledge base"

    expected_description = "A knowledge base that can be used to answer questions."
    assert tool_under_test.description == expected_description

    answer = tool_under_test.run(
        "What is the answer to life, the universe and everything?"
    )
    assert answer == "42"