mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 02:58:30 +00:00
First take on a rag tool
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from .base_tool import BaseTool, Tool, as_tool, tool
|
||||
|
||||
14
src/crewai_tools/adapters/embedchain_adapter.py
Normal file
14
src/crewai_tools/adapters/embedchain_adapter.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from embedchain import App
|
||||
|
||||
from crewai_tools.rag_tool import Adapter
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
    """Adapter that answers questions by delegating to an embedchain ``App``."""

    # The embedchain application that is queried.
    embedchain_app: App
    # Forwarded unchanged as the ``dry_run`` keyword of ``App.query``.
    dry_run: bool = False

    def query(self, question: str) -> str:
        """Query the embedchain app and return its answer as a single string.

        Args:
            question: The question passed to ``embedchain_app.query``.

        Returns:
            The result converted to ``str``. When the app returns a list,
            its items are joined with newlines.
        """
        result = self.embedchain_app.query(question, dry_run=self.dry_run)
        # The previous check, `result is list`, compared the value with the
        # `list` type object itself and therefore never matched a list value.
        if isinstance(result, list):
            # Stringify each item so non-str entries cannot break the join.
            return "\n".join(str(item) for item in result)
        return str(result)
|
||||
49
src/crewai_tools/adapters/lancedb_adapter.py
Normal file
49
src/crewai_tools/adapters/lancedb_adapter.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from lancedb import DBConnection as LanceDBConnection
|
||||
from lancedb import connect as lancedb_connect
|
||||
from lancedb.table import Table as LanceDBTable
|
||||
from openai import Client as OpenAIClient
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag_tool import Adapter
|
||||
|
||||
|
||||
def _default_embedding_function():
    """Build the default embedding callable backed by an OpenAI client.

    A single client is created here and captured by the returned closure, so
    every call of the returned function reuses it.

    Returns:
        A callable that takes ``input`` and returns a list with one embedding
        per record in the API response.
    """
    openai_client = OpenAIClient()

    def _embedding_function(input):
        """Embed ``input`` with the "text-embedding-ada-002" model."""
        response = openai_client.embeddings.create(
            input=input,
            model="text-embedding-ada-002",
        )
        vectors = []
        for record in response.data:
            vectors.append(record.embedding)
        return vectors

    return _embedding_function
|
||||
|
||||
|
||||
class LanceDBAdapter(Adapter):
    """Adapter that answers questions via vector search on a LanceDB table."""

    # Location handed to ``lancedb.connect``.
    uri: str | Path
    # Name of the table opened on the connection.
    table_name: str
    # Called with a list of texts; the first returned item is used as the
    # search vector.
    embedding_function: Callable = Field(default_factory=_default_embedding_function)
    # Maximum number of rows fetched per query.
    top_k: int = 3
    # Column searched against.
    vector_column_name: str = "vector"
    # Column whose values make up the answer.
    text_column_name: str = "text"

    _db: LanceDBConnection = PrivateAttr()
    _table: LanceDBTable = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Connect to the database and open the configured table."""
        connection = lancedb_connect(self.uri)
        self._db = connection
        self._table = connection.open_table(self.table_name)

        return super().model_post_init(__context)

    def query(self, question: str) -> str:
        """Embed ``question``, search the table and return the matching texts.

        Args:
            question: The text to embed and search for.

        Returns:
            The ``text_column_name`` values of up to ``top_k`` results,
            joined with newlines.
        """
        embedded = self.embedding_function([question])
        question_vector = embedded[0]

        search = self._table.search(
            question_vector, vector_column_name=self.vector_column_name
        )
        rows = search.limit(self.top_k).select([self.text_column_name]).to_list()

        texts = []
        for row in rows:
            texts.append(row[self.text_column_name])
        return "\n".join(texts)
|
||||
82
src/crewai_tools/base_tool.py
Normal file
82
src/crewai_tools/base_tool.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from langchain.agents import tools as langchain_tools
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
    """Abstract base for tools: a named, described unit of executable work."""

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Used to tell the model how/when/why to use the tool."""

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool by forwarding all arguments to ``_run``."""
        outcome = self._run(*args, **kwargs)
        return outcome

    @abstractmethod
    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Here goes the actual implementation of the tool."""

    def to_langchain(self) -> langchain_tools.Tool:
        """Wrap this tool in a langchain ``Tool`` with the same name and description.

        The langchain tool's ``func`` is this tool's ``_run`` method.
        """
        tool_fields = {
            "name": self.name,
            "description": self.description,
            "func": self._run,
        }
        return langchain_tools.Tool(**tool_fields)
|
||||
|
||||
|
||||
class Tool(BaseTool):
    """A tool whose implementation is a plain callable stored on the model."""

    func: Callable
    """The function that will be executed when the tool is called."""

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result."""
        wrapped = self.func
        return wrapped(*args, **kwargs)
|
||||
|
||||
|
||||
def to_langchain(
    tools: list[BaseTool | langchain_tools.BaseTool],
) -> list[langchain_tools.BaseTool]:
    """Convert a mixed list of tools into langchain tools.

    Items that are ``BaseTool`` instances are converted through their
    ``to_langchain`` method; every other item is passed through unchanged.
    Order is preserved.
    """
    converted = []
    for candidate in tools:
        if isinstance(candidate, BaseTool):
            converted.append(candidate.to_langchain())
        else:
            converted.append(candidate)
    return converted
|
||||
|
||||
|
||||
def tool(*args):
    """
    Decorator to create a tool from a function.

    Accepts exactly one positional argument. If it is a callable, the tool is
    built immediately and named after the callable's ``__name__``. If it is a
    string, a decorator is returned that builds a tool with that name. The
    decorated function's docstring becomes the tool description.

    Raises:
        ValueError: If the arguments match neither form ("Invalid arguments"),
            or if the decorated function has no docstring.
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(f: Callable) -> BaseTool:
            doc = f.__doc__
            if doc is None:
                raise ValueError("Function must have a docstring")
            return Tool(name=tool_name, description=doc, func=f)

        return _make_tool

    if len(args) != 1:
        raise ValueError("Invalid arguments")

    (target,) = args
    if callable(target):
        # Bare usage: @tool placed directly on a function.
        return _make_with_name(target.__name__)(target)
    if isinstance(target, str):
        # Named usage: @tool("some name").
        return _make_with_name(target)
    raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
def as_tool(f: Any) -> BaseTool:
    """
    Useful for when you create a tool using the @tool decorator and want to use it as a BaseTool.
    It is a BaseTool, but type inference doesn't know that.

    Asserts that ``f`` is a ``BaseTool`` and returns the same object, typed
    as ``BaseTool``.
    """
    assert isinstance(f, BaseTool)
    narrowed = cast(BaseTool, f)
    return narrowed
|
||||
74
src/crewai_tools/rag_tool.py
Normal file
74
src/crewai_tools/rag_tool.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from crewai_tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
    """Abstract interface between a RAG tool and a concrete knowledge base."""

    # Lets subclasses declare fields whose types are not pydantic-native.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def query(self, question: str) -> str:
        """Query the knowledge base with a question and return the answer."""
        ...
|
||||
|
||||
|
||||
class RagTool(BaseTool):
    """Tool that answers questions by querying a knowledge base through an ``Adapter``.

    The ``from_*`` methods are alternate constructors that build an
    embedchain-backed tool. They are classmethods so they can be called on the
    class itself (``RagTool.from_file(path)``); calling them on an existing
    instance keeps working.
    """

    name: str = "Knowledge base"
    description: str = "A knowledge base that can be used to answer questions."
    # Backend that actually answers the question.
    adapter: Adapter

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Answer the question given as the first positional argument.

        Keyword arguments are accepted for interface compatibility but are
        not used.
        """
        return self.adapter.query(args[0])

    @classmethod
    def _from_embedchain_app(cls, app: Any) -> "RagTool":
        """Wrap an embedchain app in an ``EmbedchainAdapter`` and build the tool."""
        # Imported lazily: the adapter module imports ``Adapter`` from this
        # module, so a top-level import here would be circular.
        from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter

        return cls(adapter=EmbedchainAdapter(embedchain_app=app))

    @classmethod
    def from_file(cls, file_path: str) -> "RagTool":
        """Create a tool backed by an embedchain app loaded with a text file.

        Args:
            file_path: Path added to the app with ``DataType.TEXT_FILE``.
        """
        from embedchain import App
        from embedchain.models.data_type import DataType

        app = App()
        app.add(file_path, data_type=DataType.TEXT_FILE)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_directory(cls, directory_path: str) -> "RagTool":
        """Create a tool backed by an embedchain app loaded from a directory.

        Args:
            directory_path: Directory added through a ``DirectoryLoader``
                configured with ``recursive=True``.
        """
        from embedchain import App
        from embedchain.loaders.directory_loader import DirectoryLoader

        loader = DirectoryLoader(config=dict(recursive=True))

        app = App()
        app.add(directory_path, loader=loader)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_web_page(cls, url: str) -> "RagTool":
        """Create a tool backed by an embedchain app loaded with a web page.

        Args:
            url: Address added to the app with ``DataType.WEB_PAGE``.
        """
        from embedchain import App
        from embedchain.models.data_type import DataType

        app = App()
        app.add(url, data_type=DataType.WEB_PAGE)
        return cls._from_embedchain_app(app)

    @classmethod
    def from_embedchain(cls, config_path: str) -> "RagTool":
        """Create a tool from an embedchain configuration file.

        Args:
            config_path: Passed to ``App.from_config`` as ``config_path``.
        """
        from embedchain import App

        app = App.from_config(config_path=config_path)
        return cls._from_embedchain_app(app)
|
||||
Reference in New Issue
Block a user