diff --git a/src/crewai_tools/adapters/pdf_embedchain_adapter.py b/src/crewai_tools/adapters/pdf_embedchain_adapter.py new file mode 100644 index 000000000..12557c971 --- /dev/null +++ b/src/crewai_tools/adapters/pdf_embedchain_adapter.py @@ -0,0 +1,32 @@ +from typing import Any, Optional + +from embedchain import App + +from crewai_tools.tools.rag.rag_tool import Adapter + + +class PDFEmbedchainAdapter(Adapter): + embedchain_app: App + summarize: bool = False + src: Optional[str] = None + + def query(self, question: str) -> str: + where = ( + {"app_id": self.embedchain_app.config.id, "source": self.src} + if self.src + else None + ) + result, sources = self.embedchain_app.query( + question, citations=True, dry_run=(not self.summarize), where=where + ) + if self.summarize: + return result + return "\n\n".join([source[0] for source in sources]) + + def add( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.src = args[0] if args else None + self.embedchain_app.add(*args, **kwargs) diff --git a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py index af95ae0bf..48df8e966 100644 --- a/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py +++ b/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Type from embedchain.models.data_type import DataType +from pydantic import model_validator from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool @@ -35,6 +36,22 @@ class PDFSearchTool(RagTool): self.args_schema = FixedPDFSearchToolSchema self._generate_description() + @model_validator(mode="after") + def _set_default_adapter(self): + if isinstance(self.adapter, RagTool._AdapterPlaceholder): + from embedchain import App + + from crewai_tools.adapters.pdf_embedchain_adapter import ( + PDFEmbedchainAdapter, + ) + + app = App.from_config(config=self.config) if self.config else App() + self.adapter = PDFEmbedchainAdapter( + embedchain_app=app, summarize=self.summarize + ) + + return self + def add( self, *args: Any,