Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,12 +1,25 @@
from typing import Any
from embedchain import App
from crewai_tools.tools.rag.rag_tool import Adapter
class EmbedchainAdapter(Adapter):
embedchain_app: Any
embedchain_app: App
summarize: bool = False
def query(self, question: str) -> str:
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize))
result, sources = self.embedchain_app.query(
question, citations=True, dry_run=(not self.summarize)
)
if self.summarize:
return result
return "\n\n".join([source[0] for source in sources])
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.embedchain_app.add(*args, **kwargs)

View File

@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name)
return super().model_post_init(__context)
super().model_post_init(__context)
def query(self, question: str) -> str:
query = self.embedding_function([question])[0]
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
)
values = [result[self.text_column_name] for result in results]
return "\n".join(values)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self._table.add(*args, **kwargs)