mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
Custom model config for RAG tools
This commit is contained in:
@@ -1,38 +1,71 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic.v1 import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from crewai_tools.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def query(self, question: str) -> str:
|
||||
"""Query the knowledge base with a question and return the answer."""
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
|
||||
class RagTool(BaseTool):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class _AdapterPlaceholder(Adapter):
|
||||
def query(self, question: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base that can be used to answer questions."
|
||||
summarize: bool = False
|
||||
adapter: Optional[Adapter] = None
|
||||
app: Optional[Any] = None
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
self.adapter = EmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
|
||||
self._before_run(query, **kwargs)
|
||||
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
|
||||
def from_embedchain(self, config_path: str):
|
||||
from embedchain import App
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
|
||||
app = App.from_config(config_path=config_path)
|
||||
adapter = EmbedchainAdapter(embedchain_app=app)
|
||||
return RagTool(name=self.name, description=self.description, adapter=adapter)
|
||||
def _before_run(self, query, **kwargs):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user