Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory')
# Example: Loading from a web page
rag_tool = RagTool().from_web_page('https://example.com')
# Example: Loading from an Embedchain configuration
rag_tool = RagTool().from_embedchain('path/to/your/config.json')
```
## **Contribution**
@@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To
RagTool is open-source and available under the MIT license.
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.

View File

@@ -1,38 +1,71 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any
from pydantic.v1 import BaseModel, ConfigDict
from pydantic import BaseModel, Field, model_validator
from crewai_tools.tools.base_tool import BaseTool
class Adapter(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
@abstractmethod
def query(self, question: str) -> str:
"""Query the knowledge base with a question and return the answer."""
@abstractmethod
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Add content to the knowledge base."""
class RagTool(BaseTool):
model_config = ConfigDict(arbitrary_types_allowed=True)
class _AdapterPlaceholder(Adapter):
def query(self, question: str) -> str:
raise NotImplementedError
def add(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
name: str = "Knowledge base"
description: str = "A knowledge base that can be used to answer questions."
summarize: bool = False
adapter: Optional[Adapter] = None
app: Optional[Any] = None
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: dict[str, Any] | None = None
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config=self.config) if self.config else App()
self.adapter = EmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.adapter.add(*args, **kwargs)
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
self._before_run(query, **kwargs)
return f"Relevant Content:\n{self.adapter.query(query)}"
def from_embedchain(self, config_path: str):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config_path=config_path)
adapter = EmbedchainAdapter(embedchain_app=app)
return RagTool(name=self.name, description=self.description, adapter=adapter)
def _before_run(self, query, **kwargs):
pass