From a95f5c27c68fa846139ae81ec9478a4b9f91c553 Mon Sep 17 00:00:00 2001 From: Seth Donaldson Date: Wed, 26 Jun 2024 15:52:54 -0400 Subject: [PATCH] Create PDFEmbedchainAdapter class and utilize it in PDFSearchTool --- .../adapters/pdf_embedchain_adapter.py | 13 ++++++++++--- .../tools/pdf_search_tool/pdf_search_tool.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/crewai_tools/adapters/pdf_embedchain_adapter.py b/src/crewai_tools/adapters/pdf_embedchain_adapter.py index 446aab96c..12557c971 100644 --- a/src/crewai_tools/adapters/pdf_embedchain_adapter.py +++ b/src/crewai_tools/adapters/pdf_embedchain_adapter.py @@ -1,17 +1,23 @@ -from typing import Any +from typing import Any, Optional from embedchain import App from crewai_tools.tools.rag.rag_tool import Adapter -class EmbedchainAdapter(Adapter): +class PDFEmbedchainAdapter(Adapter): embedchain_app: App summarize: bool = False + src: Optional[str] = None def query(self, question: str) -> str: + where = ( + {"app_id": self.embedchain_app.config.id, "source": self.src} + if self.src + else None + ) result, sources = self.embedchain_app.query( - question, citations=True, dry_run=(not self.summarize) + question, citations=True, dry_run=(not self.summarize), where=where ) if self.summarize: return result @@ -22,4 +28,5 @@ class EmbedchainAdapter(Adapter): *args: Any, **kwargs: Any, ) -> None: + self.src = args[0] if args else None self.embedchain_app.add(*args, **kwargs) diff --git a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py index af95ae0bf..48df8e966 100644 --- a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py +++ b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Type from embedchain.models.data_type import DataType +from pydantic import model_validator from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool @@ -35,6 +36,22 @@ class PDFSearchTool(RagTool): self.args_schema = FixedPDFSearchToolSchema self._generate_description() + @model_validator(mode="after") + def _set_default_adapter(self): + if isinstance(self.adapter, RagTool._AdapterPlaceholder): + from embedchain import App + + from crewai_tools.adapters.pdf_embedchain_adapter import ( + PDFEmbedchainAdapter, + ) + + app = App.from_config(config=self.config) if self.config else App() + self.adapter = PDFEmbedchainAdapter( + embedchain_app=app, summarize=self.summarize + ) + + return self + def add( self, *args: Any,