mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
This commit adds a file lock to `RagTool`. As it uses Embedchain and Chroma internally, this tool was not process-safe.
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
import portalocker
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any
|
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
|
|
from crewai.tools import BaseTool
|
|
|
|
|
|
class Adapter(BaseModel, ABC):
    # Abstract interface for a knowledge-base backend. Concrete adapters
    # must provide both `add` (ingest content) and `query` (answer questions).

    # Permit field types that pydantic has no built-in validation for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def query(self, question: str) -> str:
        """Ask the knowledge base *question* and return its answer as text."""

    @abstractmethod
    def add(self, *args: Any, **kwargs: Any) -> None:
        """Ingest new content into the knowledge base.

        The accepted positional and keyword arguments are defined by the
        concrete adapter implementation.
        """
|
|
|
|
|
|
class RagTool(BaseTool):
    # Tool that answers questions by delegating to an `Adapter`. When the
    # caller does not supply an adapter, an Embedchain-backed one is created
    # by the `_set_default_adapter` validator.

    class _AdapterPlaceholder(Adapter):
        # Sentinel used as the default value of `adapter`. Its only job is to
        # be recognisable by `_set_default_adapter`; both operations refuse
        # to run.

        def query(self, question: str) -> str:
            raise NotImplementedError

        def add(self, *args: Any, **kwargs: Any) -> None:
            raise NotImplementedError

    name: str = "Knowledge base"
    description: str = "A knowledge base that can be used to answer questions."
    # Forwarded to the default EmbedchainAdapter when it is constructed.
    summarize: bool = False
    # Backend used for `add` and `_run`; the placeholder default is swapped
    # out by `_set_default_adapter`.
    adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
    # Optional Embedchain configuration; a falsy value means "use App()".
    config: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _set_default_adapter(self):
        """Replace the placeholder adapter with an Embedchain-backed one.

        An adapter supplied by the caller is left untouched. Returns the
        model instance itself.
        """
        # Guard clause: a real adapter was provided, nothing to build.
        if not isinstance(self.adapter, RagTool._AdapterPlaceholder):
            return self

        # Imported lazily so embedchain is only needed for the default path.
        from embedchain import App

        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        # Hold a file lock while the Embedchain app is created so that
        # concurrent processes do not initialise it at the same time.
        with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
            if self.config:
                app = App.from_config(config=self.config)
            else:
                app = App()

        self.adapter = EmbedchainAdapter(
            embedchain_app=app,
            summarize=self.summarize,
        )
        return self

    def add(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments unchanged to the adapter's ``add``."""
        self.adapter.add(*args, **kwargs)

    def _run(self, query: str) -> str:
        """Query the adapter and return its answer under a fixed header."""
        answer = self.adapter.query(query)
        return f"Relevant Content:\n{answer}"
|