mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
First take on a rag tool
This commit is contained in:
101
README.md
101
README.md
@@ -0,0 +1,101 @@
|
||||
## Getting started
|
||||
|
||||
When setting up agents you can provide tools for them to use. Here you will find ready-to-use tools as well as simple helpers for you to create your own tools.
|
||||
|
||||
In order to create a new tool, you have to pick one of the available strategies.
|
||||
|
||||
### Subclassing `BaseTool`
|
||||
|
||||
```python
|
||||
class MyTool(BaseTool):
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base with all the requirements for the project."
|
||||
|
||||
def _run(self, question) -> str:
|
||||
return (
|
||||
tbl.search(embed_func([question])[0]).limit(3).to_pandas()["text"].tolist()
|
||||
)
|
||||
```
|
||||
|
||||
As you can see, all you need to do is to create a new class that inherits from `BaseTool`, define `name` and `description` fields, as well as implement the `_run` method.
|
||||
|
||||
### Create tool from a function or lambda
|
||||
|
||||
```python
|
||||
my_tool = Tool(
|
||||
name="Knowledge base",
|
||||
description="A knowledge base with all the requirements for the project.",
|
||||
func=lambda question: tbl.search(embed_func([question])[0])
|
||||
.limit(3)
|
||||
.to_pandas()["text"]
|
||||
.tolist(),
|
||||
)
|
||||
```
|
||||
|
||||
Here's it's a bit simpler, as you don't have to subclass. Simply create a `Tool` object with the three required fields and you are good to go.
|
||||
|
||||
### Use the `tool` decorator.
|
||||
|
||||
```python
|
||||
@tool("Knowledge base")
|
||||
def my_tool(question: str) -> str:
|
||||
"""A knowledge base with all the requirements for the project."""
|
||||
return tbl.search(embed_func([question])[0]).limit(3).to_pandas()["text"].tolist()
|
||||
```
|
||||
|
||||
By using the decorator you can easily wrap simple functions as tools. If you don't provide a name, the function name is going to be used. However, the docstring is required.
|
||||
|
||||
If you are using a linter you may see issues when passing your decorated tool in `tools` parameters that expect a list of `BaseTool`. If that's the case, you can use the `as_tool` helper.
|
||||
|
||||
|
||||
## Contribution
|
||||
|
||||
This repo is open-source and we welcome contributions. If you're looking to contribute, please:
|
||||
|
||||
- Fork the repository.
|
||||
- Create a new branch for your feature.
|
||||
- Add your feature or improvement.
|
||||
- Send a pull request.
|
||||
- We appreciate your input!
|
||||
|
||||
### Installing Dependencies
|
||||
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
### Virtual Env
|
||||
|
||||
```bash
|
||||
poetry shell
|
||||
```
|
||||
|
||||
### Pre-commit hooks
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
poetry run pytest
|
||||
```
|
||||
|
||||
### Running static type checks
|
||||
|
||||
```bash
|
||||
poetry run pyright
|
||||
```
|
||||
|
||||
### Packaging
|
||||
|
||||
```bash
|
||||
poetry build
|
||||
```
|
||||
|
||||
### Installing Locally
|
||||
|
||||
```bash
|
||||
pip install dist/*.tar.gz
|
||||
```
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .base_tool import BaseTool, Tool, as_tool, tool
|
||||
|
||||
14
src/crewai_tools/adapters/embedchain_adapter.py
Normal file
14
src/crewai_tools/adapters/embedchain_adapter.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.rag_tool import Adapter
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
|
||||
embedchain_app: App
|
||||
dry_run: bool = False
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
result = self.embedchain_app.query(question, dry_run=self.dry_run)
|
||||
if result is list:
|
||||
return "\n".join(result)
|
||||
return str(result)
|
||||
49
src/crewai_tools/adapters/lancedb_adapter.py
Normal file
49
src/crewai_tools/adapters/lancedb_adapter.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from lancedb import DBConnection as LanceDBConnection
|
||||
from lancedb import connect as lancedb_connect
|
||||
from lancedb.table import Table as LanceDBTable
|
||||
from openai import Client as OpenAIClient
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag_tool import Adapter
|
||||
|
||||
|
||||
def _default_embedding_function():
|
||||
client = OpenAIClient()
|
||||
|
||||
def _embedding_function(input):
|
||||
rs = client.embeddings.create(input=input, model="text-embedding-ada-002")
|
||||
return [record.embedding for record in rs.data]
|
||||
|
||||
return _embedding_function
|
||||
|
||||
|
||||
class LanceDBAdapter(Adapter):
|
||||
uri: str | Path
|
||||
table_name: str
|
||||
embedding_function: Callable = Field(default_factory=_default_embedding_function)
|
||||
top_k: int = 3
|
||||
vector_column_name: str = "vector"
|
||||
text_column_name: str = "text"
|
||||
|
||||
_db: LanceDBConnection = PrivateAttr()
|
||||
_table: LanceDBTable = PrivateAttr()
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
|
||||
return super().model_post_init(__context)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
query = self.embedding_function([question])[0]
|
||||
results = (
|
||||
self._table.search(query, vector_column_name=self.vector_column_name)
|
||||
.limit(self.top_k)
|
||||
.select([self.text_column_name])
|
||||
.to_list()
|
||||
)
|
||||
values = [result[self.text_column_name] for result in results]
|
||||
return "\n".join(values)
|
||||
82
src/crewai_tools/base_tool.py
Normal file
82
src/crewai_tools/base_tool.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from langchain.agents import tools as langchain_tools
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._run(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
|
||||
def to_langchain(self) -> langchain_tools.Tool:
|
||||
return langchain_tools.Tool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
func=self._run,
|
||||
)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
func: Callable
|
||||
"""The function that will be executed when the tool is called."""
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | langchain_tools.BaseTool],
|
||||
) -> list[langchain_tools.BaseTool]:
|
||||
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args):
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
def as_tool(f: Any) -> BaseTool:
|
||||
"""
|
||||
Useful for when you create a tool using the @tool decorator and want to use it as a BaseTool.
|
||||
It is a BaseTool, but type inference doesn't know that.
|
||||
"""
|
||||
assert isinstance(f, BaseTool)
|
||||
return cast(BaseTool, f)
|
||||
74
src/crewai_tools/rag_tool.py
Normal file
74
src/crewai_tools/rag_tool.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from crewai_tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def query(self, question: str) -> str:
|
||||
"""Query the knowledge base with a question and return the answer."""
|
||||
|
||||
|
||||
class RagTool(BaseTool):
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base that can be used to answer questions."
|
||||
adapter: Adapter
|
||||
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self.adapter.query(args[0])
|
||||
|
||||
def from_file(self, file_path: str):
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App()
|
||||
app.add(file_path, data_type=DataType.TEXT_FILE)
|
||||
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(adapter=adapter)
|
||||
|
||||
def from_directory(self, directory_path: str):
|
||||
from embedchain import App
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
loader = DirectoryLoader(config=dict(recursive=True))
|
||||
|
||||
app = App()
|
||||
app.add(directory_path, loader=loader)
|
||||
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(adapter=adapter)
|
||||
|
||||
def from_web_page(self, url: str):
|
||||
from embedchain import App
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App()
|
||||
app.add(url, data_type=DataType.WEB_PAGE)
|
||||
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(adapter=adapter)
|
||||
|
||||
def from_embedchain(self, config_path: str):
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App.from_config(config_path=config_path)
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(adapter=adapter)
|
||||
67
tests/adapters/embedchain_adapter_test.py
Normal file
67
tests/adapters/embedchain_adapter_test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Callable
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
|
||||
class MockEmbeddingFunction(EmbeddingFunction):
|
||||
fn: Callable
|
||||
|
||||
def __init__(self, embedding_fn: Callable):
|
||||
self.fn = embedding_fn
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
return self.fn(input)
|
||||
|
||||
|
||||
def test_embedchain_adapter(helpers):
|
||||
embedding_function = MockEmbeddingFunction(
|
||||
embedding_fn=helpers.get_embedding_function()
|
||||
)
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_embedding_fn(embedding_function) # type: ignore
|
||||
|
||||
db = ChromaDB(
|
||||
config=ChromaDbConfig(
|
||||
dir="tests/data/chromadb",
|
||||
collection_name="requirements",
|
||||
)
|
||||
)
|
||||
|
||||
app = App(
|
||||
config=AppConfig(
|
||||
id="test",
|
||||
),
|
||||
db=db,
|
||||
embedding_model=embedder,
|
||||
)
|
||||
|
||||
adapter = EmbedchainAdapter(
|
||||
dry_run=True,
|
||||
embedchain_app=app,
|
||||
)
|
||||
|
||||
assert (
|
||||
adapter.query("What are the requirements for the task?")
|
||||
== """
|
||||
Use the following pieces of context to answer the query at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
Technical requirements
|
||||
|
||||
The system should be able to process 1000 transactions per second. The code must be written in Ruby. | Problem
|
||||
|
||||
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem. | Solution
|
||||
|
||||
We need a function that takes a string as input and returns true if the string is a palindrome, otherwise false.
|
||||
|
||||
Query: What are the requirements for the task?
|
||||
|
||||
Helpful Answer:
|
||||
"""
|
||||
)
|
||||
22
tests/adapters/lancedb_adapter_test.py
Normal file
22
tests/adapters/lancedb_adapter_test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter
|
||||
|
||||
|
||||
def test_lancedb_adapter(helpers):
|
||||
adapter = LanceDBAdapter(
|
||||
uri="tests/data/lancedb",
|
||||
table_name="requirements",
|
||||
embedding_function=helpers.get_embedding_function(),
|
||||
top_k=2,
|
||||
vector_column_name="vector",
|
||||
text_column_name="text",
|
||||
)
|
||||
|
||||
assert (
|
||||
adapter.query("What are the requirements for the task?")
|
||||
== """Technical requirements
|
||||
|
||||
The system should be able to process 1000 transactions per second. The code must be written in Ruby.
|
||||
Problem
|
||||
|
||||
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem."""
|
||||
)
|
||||
21
tests/conftest.py
Normal file
21
tests/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class Helpers:
|
||||
@staticmethod
|
||||
def get_embedding_function() -> Callable:
|
||||
def _func(input):
|
||||
assert input == ["What are the requirements for the task?"]
|
||||
with open("tests/data/embedding.txt", "r") as file:
|
||||
content = file.read()
|
||||
numbers = content.split(",")
|
||||
return [[float(number) for number in numbers]]
|
||||
|
||||
return _func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helpers():
|
||||
return Helpers
|
||||
BIN
tests/data/chromadb/chroma.sqlite3
Normal file
BIN
tests/data/chromadb/chroma.sqlite3
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1
tests/data/embedding.txt
Normal file
1
tests/data/embedding.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
tests/data/lancedb/requirements.lance/_latest.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_latest.manifest
Normal file
Binary file not shown.
@@ -0,0 +1 @@
|
||||
$d2c46569-d173-4b3f-b589-f8f00eddc371<37>Vtext <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>*string085vector <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>*fixed_size_list:float:153608
|
||||
Binary file not shown.
BIN
tests/data/lancedb/requirements.lance/_versions/1.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_versions/1.manifest
Normal file
Binary file not shown.
BIN
tests/data/lancedb/requirements.lance/_versions/2.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_versions/2.manifest
Normal file
Binary file not shown.
Binary file not shown.
21
tests/rag_tool_test.py
Normal file
21
tests/rag_tool_test.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from crewai_tools.rag_tool import Adapter, RagTool
|
||||
|
||||
|
||||
class MockAdapter(Adapter):
|
||||
answer: str
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
return self.answer
|
||||
|
||||
|
||||
def test_rag_tool():
|
||||
adapter = MockAdapter(answer="42")
|
||||
rag_tool = RagTool(adapter=adapter)
|
||||
|
||||
assert rag_tool.name == "Knowledge base"
|
||||
assert (
|
||||
rag_tool.description == "A knowledge base that can be used to answer questions."
|
||||
)
|
||||
assert (
|
||||
rag_tool.run("What is the answer to life, the universe and everything?") == "42"
|
||||
)
|
||||
Reference in New Issue
Block a user