mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 22:08:21 +00:00
First take on a rag tool
This commit is contained in:
67
tests/adapters/embedchain_adapter_test.py
Normal file
67
tests/adapters/embedchain_adapter_test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Callable
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
|
||||
class MockEmbeddingFunction(EmbeddingFunction):
|
||||
fn: Callable
|
||||
|
||||
def __init__(self, embedding_fn: Callable):
|
||||
self.fn = embedding_fn
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
return self.fn(input)
|
||||
|
||||
|
||||
def test_embedchain_adapter(helpers):
|
||||
embedding_function = MockEmbeddingFunction(
|
||||
embedding_fn=helpers.get_embedding_function()
|
||||
)
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_embedding_fn(embedding_function) # type: ignore
|
||||
|
||||
db = ChromaDB(
|
||||
config=ChromaDbConfig(
|
||||
dir="tests/data/chromadb",
|
||||
collection_name="requirements",
|
||||
)
|
||||
)
|
||||
|
||||
app = App(
|
||||
config=AppConfig(
|
||||
id="test",
|
||||
),
|
||||
db=db,
|
||||
embedding_model=embedder,
|
||||
)
|
||||
|
||||
adapter = EmbedchainAdapter(
|
||||
dry_run=True,
|
||||
embedchain_app=app,
|
||||
)
|
||||
|
||||
assert (
|
||||
adapter.query("What are the requirements for the task?")
|
||||
== """
|
||||
Use the following pieces of context to answer the query at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
Technical requirements
|
||||
|
||||
The system should be able to process 1000 transactions per second. The code must be written in Ruby. | Problem
|
||||
|
||||
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem. | Solution
|
||||
|
||||
We need a function that takes a string as input and returns true if the string is a palindrome, otherwise false.
|
||||
|
||||
Query: What are the requirements for the task?
|
||||
|
||||
Helpful Answer:
|
||||
"""
|
||||
)
|
||||
22
tests/adapters/lancedb_adapter_test.py
Normal file
22
tests/adapters/lancedb_adapter_test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter
|
||||
|
||||
|
||||
def test_lancedb_adapter(helpers):
|
||||
adapter = LanceDBAdapter(
|
||||
uri="tests/data/lancedb",
|
||||
table_name="requirements",
|
||||
embedding_function=helpers.get_embedding_function(),
|
||||
top_k=2,
|
||||
vector_column_name="vector",
|
||||
text_column_name="text",
|
||||
)
|
||||
|
||||
assert (
|
||||
adapter.query("What are the requirements for the task?")
|
||||
== """Technical requirements
|
||||
|
||||
The system should be able to process 1000 transactions per second. The code must be written in Ruby.
|
||||
Problem
|
||||
|
||||
Currently, we are not able to find out palindromes in a given string. We need a solution to this problem."""
|
||||
)
|
||||
21
tests/conftest.py
Normal file
21
tests/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class Helpers:
|
||||
@staticmethod
|
||||
def get_embedding_function() -> Callable:
|
||||
def _func(input):
|
||||
assert input == ["What are the requirements for the task?"]
|
||||
with open("tests/data/embedding.txt", "r") as file:
|
||||
content = file.read()
|
||||
numbers = content.split(",")
|
||||
return [[float(number) for number in numbers]]
|
||||
|
||||
return _func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helpers():
|
||||
return Helpers
|
||||
BIN
tests/data/chromadb/chroma.sqlite3
Normal file
BIN
tests/data/chromadb/chroma.sqlite3
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1
tests/data/embedding.txt
Normal file
1
tests/data/embedding.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
tests/data/lancedb/requirements.lance/_latest.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_latest.manifest
Normal file
Binary file not shown.
@@ -0,0 +1 @@
|
||||
$d2c46569-d173-4b3f-b589-f8f00eddc371²Vtext ÿÿÿÿÿÿÿÿÿ*string085vector ÿÿÿÿÿÿÿÿÿ*fixed_size_list:float:153608
|
||||
Binary file not shown.
BIN
tests/data/lancedb/requirements.lance/_versions/1.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_versions/1.manifest
Normal file
Binary file not shown.
BIN
tests/data/lancedb/requirements.lance/_versions/2.manifest
Normal file
BIN
tests/data/lancedb/requirements.lance/_versions/2.manifest
Normal file
Binary file not shown.
Binary file not shown.
21
tests/rag_tool_test.py
Normal file
21
tests/rag_tool_test.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from crewai_tools.rag_tool import Adapter, RagTool
|
||||
|
||||
|
||||
class MockAdapter(Adapter):
|
||||
answer: str
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
return self.answer
|
||||
|
||||
|
||||
def test_rag_tool():
|
||||
adapter = MockAdapter(answer="42")
|
||||
rag_tool = RagTool(adapter=adapter)
|
||||
|
||||
assert rag_tool.name == "Knowledge base"
|
||||
assert (
|
||||
rag_tool.description == "A knowledge base that can be used to answer questions."
|
||||
)
|
||||
assert (
|
||||
rag_tool.run("What is the answer to life, the universe and everything?") == "42"
|
||||
)
|
||||
Reference in New Issue
Block a user