Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,38 +1,71 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any
from pydantic.v1 import BaseModel, ConfigDict
from pydantic import BaseModel, Field, model_validator
from crewai_tools.tools.base_tool import BaseTool
class Adapter(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
@abstractmethod
def query(self, question: str) -> str:
"""Query the knowledge base with a question and return the answer."""
@abstractmethod
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Add content to the knowledge base."""
class RagTool(BaseTool):
model_config = ConfigDict(arbitrary_types_allowed=True)
class _AdapterPlaceholder(Adapter):
def query(self, question: str) -> str:
raise NotImplementedError
def add(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
name: str = "Knowledge base"
description: str = "A knowledge base that can be used to answer questions."
summarize: bool = False
adapter: Optional[Adapter] = None
app: Optional[Any] = None
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: dict[str, Any] | None = None
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config=self.config) if self.config else App()
self.adapter = EmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.adapter.add(*args, **kwargs)
def _run(
self,
query: str,
**kwargs: Any,
) -> Any:
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
self._before_run(query, **kwargs)
return f"Relevant Content:\n{self.adapter.query(query)}"
def from_embedchain(self, config_path: str):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config_path=config_path)
adapter = EmbedchainAdapter(embedchain_app=app)
return RagTool(name=self.name, description=self.description, adapter=adapter)
def _before_run(self, query, **kwargs):
pass