mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
Custom model config for RAG tools
This commit is contained in:
@@ -1,12 +1,25 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
|
||||
embedchain_app: Any
|
||||
embedchain_app: App
|
||||
summarize: bool = False
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize))
|
||||
result, sources = self.embedchain_app.query(
|
||||
question, citations=True, dry_run=(not self.summarize)
|
||||
)
|
||||
if self.summarize:
|
||||
return result
|
||||
return "\n\n".join([source[0] for source in sources])
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.embedchain_app.add(*args, **kwargs)
|
||||
|
||||
@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
|
||||
self._db = lancedb_connect(self.uri)
|
||||
self._table = self._db.open_table(self.table_name)
|
||||
|
||||
return super().model_post_init(__context)
|
||||
super().model_post_init(__context)
|
||||
|
||||
def query(self, question: str) -> str:
|
||||
query = self.embedding_function([question])[0]
|
||||
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
|
||||
)
|
||||
values = [result[self.text_column_name] for result in results]
|
||||
return "\n".join(values)
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._table.add(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user