Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,12 +1,25 @@
from typing import Any from typing import Any
from embedchain import App
from crewai_tools.tools.rag.rag_tool import Adapter from crewai_tools.tools.rag.rag_tool import Adapter
class EmbedchainAdapter(Adapter): class EmbedchainAdapter(Adapter):
embedchain_app: Any embedchain_app: App
summarize: bool = False summarize: bool = False
def query(self, question: str) -> str: def query(self, question: str) -> str:
result, sources = self.embedchain_app.query(question, citations=True, dry_run=(not self.summarize)) result, sources = self.embedchain_app.query(
question, citations=True, dry_run=(not self.summarize)
)
if self.summarize: if self.summarize:
return result return result
return "\n\n".join([source[0] for source in sources]) return "\n\n".join([source[0] for source in sources])
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.embedchain_app.add(*args, **kwargs)

View File

@@ -35,7 +35,7 @@ class LanceDBAdapter(Adapter):
self._db = lancedb_connect(self.uri) self._db = lancedb_connect(self.uri)
self._table = self._db.open_table(self.table_name) self._table = self._db.open_table(self.table_name)
return super().model_post_init(__context) super().model_post_init(__context)
def query(self, question: str) -> str: def query(self, question: str) -> str:
query = self.embedding_function([question])[0] query = self.embedding_function([question])[0]
@@ -47,3 +47,10 @@ class LanceDBAdapter(Adapter):
) )
values = [result[self.text_column_name] for result in results] values = [result[self.text_column_name] for result in results]
return "\n".join(values) return "\n".join(values)
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self._table.add(*args, **kwargs)

View File

@@ -1,28 +1,47 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Type from typing import Any, Callable, Optional, Type
from pydantic import BaseModel, model_validator from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic.v1 import BaseModel as V1BaseModel from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.tools import StructuredTool
class BaseTool(BaseModel, ABC): class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(V1BaseModel):
pass
model_config = ConfigDict()
name: str name: str
"""The unique name of the tool that clearly communicates its purpose.""" """The unique name of the tool that clearly communicates its purpose."""
description: str description: str
"""Used to tell the model how/when/why to use the tool.""" """Used to tell the model how/when/why to use the tool."""
args_schema: Optional[Type[V1BaseModel]] = None args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
"""The schema for the arguments that the tool accepts.""" """The schema for the arguments that the tool accepts."""
description_updated: bool = False description_updated: bool = False
"""Flag to check if the description has been updated.""" """Flag to check if the description has been updated."""
cache_function: Optional[Callable] = lambda: True cache_function: Optional[Callable] = lambda: True
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.""" """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
@model_validator(mode="after") @validator("args_schema", always=True, pre=True)
def _check_args_schema(self): def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
self._set_args_schema() if not isinstance(v, cls._ArgsSchemaPlaceholder):
return v
return type(
f"{cls.__name__}Schema",
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in cls._run.__annotations__.items() if k != "return"
},
},
)
def model_post_init(self, __context: Any) -> None:
self._generate_description() self._generate_description()
return self
super().model_post_init(__context)
def run( def run(
self, self,
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
(V1BaseModel,), (V1BaseModel,),
{ {
"__annotations__": { "__annotations__": {
k: v for k, v in self._run.__annotations__.items() if k != 'return' k: v
for k, v in self._run.__annotations__.items()
if k != "return"
}, },
}, },
) )
def _generate_description(self): def _generate_description(self):
args = [] args = []
for arg, attribute in self.args_schema.schema()['properties'].items(): for arg, attribute in self.args_schema.schema()["properties"].items():
args.append(f"{arg}: '{attribute['type']}'") if "type" in attribute:
args.append(f"{arg}: '{attribute['type']}'")
description = self.description.replace('\n', ' ') description = self.description.replace("\n", " ")
self.description = f"{self.name}({', '.join(args)}) - {description}" self.description = f"{self.name}({', '.join(args)}) - {description}"
@@ -93,19 +116,19 @@ def tool(*args):
def _make_tool(f: Callable) -> BaseTool: def _make_tool(f: Callable) -> BaseTool:
if f.__doc__ is None: if f.__doc__ is None:
raise ValueError("Function must have a docstring") raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
raise ValueError("Function must have type annotations")
args_schema = None class_name = "".join(tool_name.split()).title()
if f.__annotations__: args_schema = type(
class_name = "".join(tool_name.split()).title() class_name,
args_schema = type( (V1BaseModel,),
class_name, {
(V1BaseModel,), "__annotations__": {
{ k: v for k, v in f.__annotations__.items() if k != "return"
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != 'return'
},
}, },
) },
)
return Tool( return Tool(
name=tool_name, name=tool_name,
@@ -120,4 +143,4 @@ def tool(*args):
return _make_with_name(args[0].__name__)(args[0]) return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str): if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0]) return _make_with_name(args[0])
raise ValueError("Invalid arguments") raise ValueError("Invalid arguments")

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedCodeDocsSearchToolSchema(BaseModel): class FixedCodeDocsSearchToolSchema(BaseModel):
"""Input for CodeDocsSearchTool.""" """Input for CodeDocsSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Code Docs content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Code Docs content",
)
class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema): class CodeDocsSearchToolSchema(FixedCodeDocsSearchToolSchema):
"""Input for CodeDocsSearchTool.""" """Input for CodeDocsSearchTool."""
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
docs_url: str = Field(..., description="Mandatory docs_url path you want to search")
class CodeDocsSearchTool(RagTool): class CodeDocsSearchTool(RagTool):
name: str = "Search a Code Docs content" name: str = "Search a Code Docs content"
description: str = "A tool that can be used to semantic search a query from a Code Docs content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a Code Docs content."
args_schema: Type[BaseModel] = CodeDocsSearchToolSchema )
docs_url: Optional[str] = None args_schema: Type[BaseModel] = CodeDocsSearchToolSchema
def __init__(self, docs_url: Optional[str] = None, **kwargs): def __init__(self, docs_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docs_url is not None: if docs_url is not None:
self.docs_url = docs_url self.add(docs_url)
self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content." self.description = f"A tool that can be used to semantic search a query the {docs_url} Code Docs content."
self.args_schema = FixedCodeDocsSearchToolSchema self.args_schema = FixedCodeDocsSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
docs_url = kwargs.get('docs_url', self.docs_url) kwargs["data_type"] = DataType.DOCS_SITE
self.app = App() super().add(*args, **kwargs)
self.app.add(docs_url, data_type=DataType.DOCS_SITE)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docs_url" in kwargs:
self.add(kwargs["docs_url"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedCSVSearchToolSchema(BaseModel): class FixedCSVSearchToolSchema(BaseModel):
"""Input for CSVSearchTool.""" """Input for CSVSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the CSV's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the CSV's content",
)
class CSVSearchToolSchema(FixedCSVSearchToolSchema): class CSVSearchToolSchema(FixedCSVSearchToolSchema):
"""Input for CSVSearchTool.""" """Input for CSVSearchTool."""
csv: str = Field(..., description="Mandatory csv path you want to search")
csv: str = Field(..., description="Mandatory csv path you want to search")
class CSVSearchTool(RagTool): class CSVSearchTool(RagTool):
name: str = "Search a CSV's content" name: str = "Search a CSV's content"
description: str = "A tool that can be used to semantic search a query from a CSV's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a CSV's content."
args_schema: Type[BaseModel] = CSVSearchToolSchema )
csv: Optional[str] = None args_schema: Type[BaseModel] = CSVSearchToolSchema
def __init__(self, csv: Optional[str] = None, **kwargs): def __init__(self, csv: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if csv is not None: if csv is not None:
self.csv = csv self.add(csv)
self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content." self.description = f"A tool that can be used to semantic search a query the {csv} CSV's content."
self.args_schema = FixedCSVSearchToolSchema self.args_schema = FixedCSVSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
csv = kwargs.get('csv', self.csv) kwargs["data_type"] = DataType.CSV
self.app = App() super().add(*args, **kwargs)
self.app.add(csv, data_type=DataType.CSV)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "csv" in kwargs:
self.add(kwargs["csv"])

View File

@@ -1,42 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.loaders.directory_loader import DirectoryLoader from embedchain.loaders.directory_loader import DirectoryLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedDirectorySearchToolSchema(BaseModel): class FixedDirectorySearchToolSchema(BaseModel):
"""Input for DirectorySearchTool.""" """Input for DirectorySearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the directory's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the directory's content",
)
class DirectorySearchToolSchema(FixedDirectorySearchToolSchema): class DirectorySearchToolSchema(FixedDirectorySearchToolSchema):
"""Input for DirectorySearchTool.""" """Input for DirectorySearchTool."""
directory: str = Field(..., description="Mandatory directory you want to search")
directory: str = Field(..., description="Mandatory directory you want to search")
class DirectorySearchTool(RagTool): class DirectorySearchTool(RagTool):
name: str = "Search a directory's content" name: str = "Search a directory's content"
description: str = "A tool that can be used to semantic search a query from a directory's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a directory's content."
args_schema: Type[BaseModel] = DirectorySearchToolSchema )
directory: Optional[str] = None args_schema: Type[BaseModel] = DirectorySearchToolSchema
def __init__(self, directory: Optional[str] = None, **kwargs): def __init__(self, directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if directory is not None: if directory is not None:
self.directory = directory self.add(directory)
self.description = f"A tool that can be used to semantic search a query the {directory} directory's content." self.description = f"A tool that can be used to semantic search a query the {directory} directory's content."
self.args_schema = FixedDirectorySearchToolSchema self.args_schema = FixedDirectorySearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
directory = kwargs.get('directory', self.directory) kwargs["loader"] = DirectoryLoader(config=dict(recursive=True))
loader = DirectoryLoader(config=dict(recursive=True)) super().add(*args, **kwargs)
self.app = App()
self.app.add(directory, loader=loader) def _before_run(
return super()._run(query=search_query) self,
query: str,
**kwargs: Any,
) -> Any:
if "directory" in kwargs:
self.add(kwargs["directory"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedDOCXSearchToolSchema(BaseModel): class FixedDOCXSearchToolSchema(BaseModel):
"""Input for DOCXSearchTool.""" """Input for DOCXSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the DOCX's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the DOCX's content",
)
class DOCXSearchToolSchema(FixedDOCXSearchToolSchema): class DOCXSearchToolSchema(FixedDOCXSearchToolSchema):
"""Input for DOCXSearchTool.""" """Input for DOCXSearchTool."""
docx: str = Field(..., description="Mandatory docx path you want to search")
docx: str = Field(..., description="Mandatory docx path you want to search")
class DOCXSearchTool(RagTool): class DOCXSearchTool(RagTool):
name: str = "Search a DOCX's content" name: str = "Search a DOCX's content"
description: str = "A tool that can be used to semantic search a query from a DOCX's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a DOCX's content."
args_schema: Type[BaseModel] = DOCXSearchToolSchema )
docx: Optional[str] = None args_schema: Type[BaseModel] = DOCXSearchToolSchema
def __init__(self, docx: Optional[str] = None, **kwargs): def __init__(self, docx: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if docx is not None: if docx is not None:
self.docx = docx self.add(docx)
self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content." self.description = f"A tool that can be used to semantic search a query the {docx} DOCX's content."
self.args_schema = FixedDOCXSearchToolSchema self.args_schema = FixedDOCXSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
docx = kwargs.get('docx', self.docx) kwargs["data_type"] = DataType.DOCX
self.app = App() super().add(*args, **kwargs)
self.app.add(docx, data_type=DataType.DOCX)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "docx" in kwargs:
self.add(kwargs["docx"])

View File

@@ -1,46 +1,58 @@
from typing import Optional, Type, List, Any from typing import Any, List, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.loaders.github import GithubLoader from embedchain.loaders.github import GithubLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedGithubSearchToolSchema(BaseModel): class FixedGithubSearchToolSchema(BaseModel):
"""Input for GithubSearchTool.""" """Input for GithubSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the github repo's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the github repo's content",
)
class GithubSearchToolSchema(FixedGithubSearchToolSchema): class GithubSearchToolSchema(FixedGithubSearchToolSchema):
"""Input for GithubSearchTool.""" """Input for GithubSearchTool."""
github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field(..., description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]") github_repo: str = Field(..., description="Mandatory github you want to search")
content_types: List[str] = Field(
...,
description="Mandatory content types you want to be inlcuded search, options: [code, repo, pr, issue]",
)
class GithubSearchTool(RagTool): class GithubSearchTool(RagTool):
name: str = "Search a github repo's content" name: str = "Search a github repo's content"
description: str = "A tool that can be used to semantic search a query from a github repo's content." description: str = "A tool that can be used to semantic search a query from a github repo's content."
summarize: bool = False summarize: bool = False
gh_token: str = None gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema args_schema: Type[BaseModel] = GithubSearchToolSchema
github_repo: Optional[str] = None content_types: List[str]
content_types: List[str]
def __init__(self, github_repo: Optional[str] = None, **kwargs): def __init__(self, github_repo: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if github_repo is not None: if github_repo is not None:
self.github_repo = github_repo self.add(github_repo)
self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content." self.description = f"A tool that can be used to semantic search a query the {github_repo} github repo's content."
self.args_schema = FixedGithubSearchToolSchema self.args_schema = FixedGithubSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
github_repo = kwargs.get('github_repo', self.github_repo) kwargs["data_type"] = "github"
loader = GithubLoader(config={"token": self.gh_token}) kwargs["loader"] = GithubLoader(config={"token": self.gh_token})
app = App() super().add(*args, **kwargs)
app.add(f"repo:{github_repo} type:{','.join(self.content_types)}", data_type="github", loader=loader)
self.app = app def _before_run(
return super()._run(query=search_query) self,
query: str,
**kwargs: Any,
) -> Any:
if "github_repo" in kwargs:
self.add(kwargs["github_repo"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedJSONSearchToolSchema(BaseModel): class FixedJSONSearchToolSchema(BaseModel):
"""Input for JSONSearchTool.""" """Input for JSONSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the JSON's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the JSON's content",
)
class JSONSearchToolSchema(FixedJSONSearchToolSchema): class JSONSearchToolSchema(FixedJSONSearchToolSchema):
"""Input for JSONSearchTool.""" """Input for JSONSearchTool."""
json_path: str = Field(..., description="Mandatory json path you want to search")
json_path: str = Field(..., description="Mandatory json path you want to search")
class JSONSearchTool(RagTool): class JSONSearchTool(RagTool):
name: str = "Search a JSON's content" name: str = "Search a JSON's content"
description: str = "A tool that can be used to semantic search a query from a JSON's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a JSON's content."
args_schema: Type[BaseModel] = JSONSearchToolSchema )
json_path: Optional[str] = None args_schema: Type[BaseModel] = JSONSearchToolSchema
def __init__(self, json_path: Optional[str] = None, **kwargs): def __init__(self, json_path: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if json_path is not None: if json_path is not None:
self.json_path = json_path self.add(json_path)
self.description = f"A tool that can be used to semantic search a query the {json} JSON's content." self.description = f"A tool that can be used to semantic search a query the {json_path} JSON's content."
self.args_schema = FixedJSONSearchToolSchema self.args_schema = FixedJSONSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
json_path = kwargs.get('json_path', self.json_path) kwargs["data_type"] = DataType.JSON
self.app = App() super().add(*args, **kwargs)
self.app.add(json_path, data_type=DataType.JSON)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "json_path" in kwargs:
self.add(kwargs["json_path"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedMDXSearchToolSchema(BaseModel): class FixedMDXSearchToolSchema(BaseModel):
"""Input for MDXSearchTool.""" """Input for MDXSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the MDX's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the MDX's content",
)
class MDXSearchToolSchema(FixedMDXSearchToolSchema): class MDXSearchToolSchema(FixedMDXSearchToolSchema):
"""Input for MDXSearchTool.""" """Input for MDXSearchTool."""
mdx: str = Field(..., description="Mandatory mdx path you want to search")
mdx: str = Field(..., description="Mandatory mdx path you want to search")
class MDXSearchTool(RagTool): class MDXSearchTool(RagTool):
name: str = "Search a MDX's content" name: str = "Search a MDX's content"
description: str = "A tool that can be used to semantic search a query from a MDX's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a MDX's content."
args_schema: Type[BaseModel] = MDXSearchToolSchema )
mdx: Optional[str] = None args_schema: Type[BaseModel] = MDXSearchToolSchema
def __init__(self, mdx: Optional[str] = None, **kwargs): def __init__(self, mdx: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if mdx is not None: if mdx is not None:
self.mdx = mdx self.add(mdx)
self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content." self.description = f"A tool that can be used to semantic search a query the {mdx} MDX's content."
self.args_schema = FixedMDXSearchToolSchema self.args_schema = FixedMDXSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
mdx = kwargs.get('mdx', self.mdx) kwargs["data_type"] = DataType.MDX
self.app = App() super().add(*args, **kwargs)
self.app.add(mdx, data_type=DataType.MDX)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "mdx" in kwargs:
self.add(kwargs["mdx"])

View File

@@ -1,41 +1,51 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedPDFSearchToolSchema(BaseModel): class FixedPDFSearchToolSchema(BaseModel):
"""Input for PDFSearchTool.""" """Input for PDFSearchTool."""
query: str = Field(..., description="Mandatory query you want to use to search the PDF's content")
query: str = Field(
..., description="Mandatory query you want to use to search the PDF's content"
)
class PDFSearchToolSchema(FixedPDFSearchToolSchema): class PDFSearchToolSchema(FixedPDFSearchToolSchema):
"""Input for PDFSearchTool.""" """Input for PDFSearchTool."""
pdf: str = Field(..., description="Mandatory pdf path you want to search")
pdf: str = Field(..., description="Mandatory pdf path you want to search")
class PDFSearchTool(RagTool): class PDFSearchTool(RagTool):
name: str = "Search a PDF's content" name: str = "Search a PDF's content"
description: str = "A tool that can be used to semantic search a query from a PDF's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a PDF's content."
args_schema: Type[BaseModel] = PDFSearchToolSchema )
pdf: Optional[str] = None args_schema: Type[BaseModel] = PDFSearchToolSchema
def __init__(self, pdf: Optional[str] = None, **kwargs): def __init__(self, pdf: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if pdf is not None: if pdf is not None:
self.pdf = pdf self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content." self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema self.args_schema = FixedPDFSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
pdf = kwargs.get('pdf', self.pdf) kwargs["data_type"] = DataType.PDF_FILE
self.app = App() super().add(*args, **kwargs)
self.app.add(pdf, data_type=DataType.PDF_FILE)
return super()._run(query=query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "pdf" in kwargs:
self.add(kwargs["pdf"])

View File

@@ -1,45 +1,37 @@
from typing import Optional, Type, Any from typing import Any, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.loaders.postgres import PostgresLoader from embedchain.loaders.postgres import PostgresLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class PGSearchToolSchema(BaseModel): class PGSearchToolSchema(BaseModel):
"""Input for PGSearchTool.""" """Input for PGSearchTool."""
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
search_query: str = Field(
...,
description="Mandatory semantic search query you want to use to search the database's content",
)
class PGSearchTool(RagTool): class PGSearchTool(RagTool):
name: str = "Search a database's table content" name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content." description: str = "A tool that can be used to semantic search a query from a database table's content."
summarize: bool = False args_schema: Type[BaseModel] = PGSearchToolSchema
args_schema: Type[BaseModel] = PGSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI")
db_uri: str = Field(..., description="Mandatory database URI")
table_name: str = Field(..., description="Mandatory table name")
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
def __init__(self, table_name: Optional[str] = None, **kwargs): def __init__(self, table_name: str, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if table_name is not None: self.add(table_name)
self.table_name = table_name self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content." self._generate_description()
self._generate_description()
else:
raise('To use PGSearchTool, you must provide a `table_name` argument')
def _run( def add(
self, self,
search_query: str, table_name: str,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
kwargs["data_type"] = "postgres"
config = { "url": self.db_uri } kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
postgres_loader = PostgresLoader(config=config) super().add(f"SELECT * FROM {table_name};", **kwargs)
app = App()
app.add(
f"SELECT * FROM {self.table_name};",
data_type='postgres',
loader=postgres_loader
)
return super()._run(query=search_query)

View File

@@ -48,9 +48,6 @@ rag_tool = RagTool().from_directory('path/to/your/directory')
# Example: Loading from a web page # Example: Loading from a web page
rag_tool = RagTool().from_web_page('https://example.com') rag_tool = RagTool().from_web_page('https://example.com')
# Example: Loading from an Embedchain configuration
rag_tool = RagTool().from_embedchain('path/to/your/config.json')
``` ```
## **Contribution** ## **Contribution**
@@ -61,4 +58,4 @@ Contributions to RagTool and the broader CrewAI tools ecosystem are welcome. To
RagTool is open-source and available under the MIT license. RagTool is open-source and available under the MIT license.
Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better. Thank you for considering RagTool for your knowledge base needs. Your contributions and feedback are invaluable to making RagTool even better.

View File

@@ -1,38 +1,71 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Optional from typing import Any
from pydantic.v1 import BaseModel, ConfigDict from pydantic import BaseModel, Field, model_validator
from crewai_tools.tools.base_tool import BaseTool from crewai_tools.tools.base_tool import BaseTool
class Adapter(BaseModel, ABC): class Adapter(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True) class Config:
arbitrary_types_allowed = True
@abstractmethod @abstractmethod
def query(self, question: str) -> str: def query(self, question: str) -> str:
"""Query the knowledge base with a question and return the answer.""" """Query the knowledge base with a question and return the answer."""
@abstractmethod
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Add content to the knowledge base."""
class RagTool(BaseTool): class RagTool(BaseTool):
model_config = ConfigDict(arbitrary_types_allowed=True) class _AdapterPlaceholder(Adapter):
def query(self, question: str) -> str:
raise NotImplementedError
def add(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
name: str = "Knowledge base" name: str = "Knowledge base"
description: str = "A knowledge base that can be used to answer questions." description: str = "A knowledge base that can be used to answer questions."
summarize: bool = False summarize: bool = False
adapter: Optional[Adapter] = None adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
app: Optional[Any] = None config: dict[str, Any] | None = None
@model_validator(mode="after")
def _set_default_adapter(self):
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from embedchain import App
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config=self.config) if self.config else App()
self.adapter = EmbedchainAdapter(
embedchain_app=app, summarize=self.summarize
)
return self
def add(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.adapter.add(*args, **kwargs)
def _run( def _run(
self, self,
query: str, query: str,
**kwargs: Any,
) -> Any: ) -> Any:
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter self._before_run(query, **kwargs)
self.adapter = EmbedchainAdapter(embedchain_app=self.app, summarize=self.summarize)
return f"Relevant Content:\n{self.adapter.query(query)}" return f"Relevant Content:\n{self.adapter.query(query)}"
def from_embedchain(self, config_path: str): def _before_run(self, query, **kwargs):
from embedchain import App pass
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
app = App.from_config(config_path=config_path)
adapter = EmbedchainAdapter(embedchain_app=app)
return RagTool(name=self.name, description=self.description, adapter=adapter)

View File

@@ -1,40 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedTXTSearchToolSchema(BaseModel): class FixedTXTSearchToolSchema(BaseModel):
"""Input for TXTSearchTool.""" """Input for TXTSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the txt's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the txt's content",
)
class TXTSearchToolSchema(FixedTXTSearchToolSchema): class TXTSearchToolSchema(FixedTXTSearchToolSchema):
"""Input for TXTSearchTool.""" """Input for TXTSearchTool."""
txt: str = Field(..., description="Mandatory txt path you want to search")
txt: str = Field(..., description="Mandatory txt path you want to search")
class TXTSearchTool(RagTool): class TXTSearchTool(RagTool):
name: str = "Search a txt's content" name: str = "Search a txt's content"
description: str = "A tool that can be used to semantic search a query from a txt's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a txt's content."
args_schema: Type[BaseModel] = TXTSearchToolSchema )
txt: Optional[str] = None args_schema: Type[BaseModel] = TXTSearchToolSchema
def __init__(self, txt: Optional[str] = None, **kwargs): def __init__(self, txt: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if txt is not None: if txt is not None:
self.txt = txt self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content." self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema self.args_schema = FixedTXTSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
txt = kwargs.get('txt', self.txt) kwargs["data_type"] = DataType.TEXT_FILE
self.app = App() super().add(*args, **kwargs)
self.app.add(txt, data_type=DataType.TEXT_FILE)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "txt" in kwargs:
self.add(kwargs["txt"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedWebsiteSearchToolSchema(BaseModel): class FixedWebsiteSearchToolSchema(BaseModel):
"""Input for WebsiteSearchTool.""" """Input for WebsiteSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search a specific website")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search a specific website",
)
class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema): class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
"""Input for WebsiteSearchTool.""" """Input for WebsiteSearchTool."""
website: str = Field(..., description="Mandatory valid website URL you want to search on")
website: str = Field(
..., description="Mandatory valid website URL you want to search on"
)
class WebsiteSearchTool(RagTool): class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website" name: str = "Search in a specific website"
description: str = "A tool that can be used to semantic search a query from a specific URL content." description: str = "A tool that can be used to semantic search a query from a specific URL content."
summarize: bool = False args_schema: Type[BaseModel] = WebsiteSearchToolSchema
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
website: Optional[str] = None
def __init__(self, website: Optional[str] = None, **kwargs): def __init__(self, website: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if website is not None: if website is not None:
self.website = website self.add(website)
self.description = f"A tool that can be used to semantic search a query from {website} website content." self.description = f"A tool that can be used to semantic search a query from {website} website content."
self.args_schema = FixedWebsiteSearchToolSchema self.args_schema = FixedWebsiteSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
website = kwargs.get('website', self.website) kwargs["data_type"] = DataType.WEB_PAGE
self.app = App() super().add(*args, **kwargs)
self.app.add(website, data_type=DataType.WEB_PAGE)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "website" in kwargs:
self.add(kwargs["website"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedXMLSearchToolSchema(BaseModel): class FixedXMLSearchToolSchema(BaseModel):
"""Input for XMLSearchTool.""" """Input for XMLSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the XML's content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the XML's content",
)
class XMLSearchToolSchema(FixedXMLSearchToolSchema): class XMLSearchToolSchema(FixedXMLSearchToolSchema):
"""Input for XMLSearchTool.""" """Input for XMLSearchTool."""
xml: str = Field(..., description="Mandatory xml path you want to search")
xml: str = Field(..., description="Mandatory xml path you want to search")
class XMLSearchTool(RagTool): class XMLSearchTool(RagTool):
name: str = "Search a XML's content" name: str = "Search a XML's content"
description: str = "A tool that can be used to semantic search a query from a XML's content." description: str = (
summarize: bool = False "A tool that can be used to semantic search a query from a XML's content."
args_schema: Type[BaseModel] = XMLSearchToolSchema )
xml: Optional[str] = None args_schema: Type[BaseModel] = XMLSearchToolSchema
def __init__(self, xml: Optional[str] = None, **kwargs): def __init__(self, xml: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if xml is not None: if xml is not None:
self.xml = xml self.add(xml)
self.description = f"A tool that can be used to semantic search a query the {xml} XML's content." self.description = f"A tool that can be used to semantic search a query the {xml} XML's content."
self.args_schema = FixedXMLSearchToolSchema self.args_schema = FixedXMLSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
xml = kwargs.get('xml', self.xml) kwargs["data_type"] = DataType.XML
self.app = App() super().add(*args, **kwargs)
self.app.add(xml, data_type=DataType.XML)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "xml" in kwargs:
self.add(kwargs["xml"])

View File

@@ -1,43 +1,55 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedYoutubeChannelSearchToolSchema(BaseModel): class FixedYoutubeChannelSearchToolSchema(BaseModel):
"""Input for YoutubeChannelSearchTool.""" """Input for YoutubeChannelSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Channels content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Youtube Channels content",
)
class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema): class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
"""Input for YoutubeChannelSearchTool.""" """Input for YoutubeChannelSearchTool."""
youtube_channel_handle: str = Field(..., description="Mandatory youtube_channel_handle path you want to search")
youtube_channel_handle: str = Field(
..., description="Mandatory youtube_channel_handle path you want to search"
)
class YoutubeChannelSearchTool(RagTool): class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content" name: str = "Search a Youtube Channels content"
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
summarize: bool = False args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
youtube_channel_handle: Optional[str] = None
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if youtube_channel_handle is not None: if youtube_channel_handle is not None:
self.youtube_channel_handle = youtube_channel_handle self.add(youtube_channel_handle)
self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content." self.description = f"A tool that can be used to semantic search a query the {youtube_channel_handle} Youtube Channels content."
self.args_schema = FixedYoutubeChannelSearchToolSchema self.args_schema = FixedYoutubeChannelSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, youtube_channel_handle: str,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
youtube_channel_handle = kwargs.get('youtube_channel_handle', self.youtube_channel_handle) if not youtube_channel_handle.startswith("@"):
if not youtube_channel_handle.startswith("@"): youtube_channel_handle = f"@{youtube_channel_handle}"
youtube_channel_handle = f"@{youtube_channel_handle}"
self.app = App() kwargs["data_type"] = DataType.YOUTUBE_CHANNEL
self.app.add(youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL) super().add(youtube_channel_handle, **kwargs)
return super()._run(query=search_query)
def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_channel_handle" in kwargs:
self.add(kwargs["youtube_channel_handle"])

View File

@@ -1,41 +1,52 @@
from typing import Optional, Type, Any from typing import Any, Optional, Type
from pydantic.v1 import BaseModel, Field
from embedchain import App
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool from ..rag.rag_tool import RagTool
class FixedYoutubeVideoSearchToolSchema(BaseModel): class FixedYoutubeVideoSearchToolSchema(BaseModel):
"""Input for YoutubeVideoSearchTool.""" """Input for YoutubeVideoSearchTool."""
search_query: str = Field(..., description="Mandatory search query you want to use to search the Youtube Video content")
search_query: str = Field(
...,
description="Mandatory search query you want to use to search the Youtube Video content",
)
class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema): class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
"""Input for YoutubeVideoSearchTool.""" """Input for YoutubeVideoSearchTool."""
youtube_video_url: str = Field(..., description="Mandatory youtube_video_url path you want to search")
youtube_video_url: str = Field(
..., description="Mandatory youtube_video_url path you want to search"
)
class YoutubeVideoSearchTool(RagTool): class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content" name: str = "Search a Youtube Video content"
description: str = "A tool that can be used to semantic search a query from a Youtube Video content." description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
summarize: bool = False args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
youtube_video_url: Optional[str] = None
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if youtube_video_url is not None: if youtube_video_url is not None:
self.youtube_video_url = youtube_video_url self.add(youtube_video_url)
self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content." self.description = f"A tool that can be used to semantic search a query the {youtube_video_url} Youtube Video content."
self.args_schema = FixedYoutubeVideoSearchToolSchema self.args_schema = FixedYoutubeVideoSearchToolSchema
self._generate_description()
def _run( def add(
self, self,
search_query: str, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> None:
youtube_video_url = kwargs.get('youtube_video_url', self.youtube_video_url) kwargs["data_type"] = DataType.YOUTUBE_VIDEO
self.app = App() super().add(*args, **kwargs)
self.app.add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
return super()._run(query=search_query) def _before_run(
self,
query: str,
**kwargs: Any,
) -> Any:
if "youtube_video_url" in kwargs:
self.add(kwargs["youtube_video_url"])

View File

@@ -0,0 +1,43 @@
import os
from tempfile import NamedTemporaryFile
from typing import cast
from unittest import mock
from pytest import fixture
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@fixture(autouse=True)
def mock_embedchain_db_uri():
    """Point ``EMBEDCHAIN_DB_URI`` at a throwaway SQLite file during each test.

    The variable is patched into ``os.environ`` only while the test body runs;
    the patch is undone and the temporary file is closed when the test ends.
    """
    with NamedTemporaryFile() as db_file:
        patched_env = {"EMBEDCHAIN_DB_URI": f"sqlite:///{db_file.name}"}
        with mock.patch.dict(os.environ, patched_env):
            # Hand control to the test while the patched environment is active.
            yield
def test_custom_llm_and_embedder():
    """A ``config`` given to a RagTool subclass reaches the embedchain app."""

    class MyTool(RagTool):
        pass

    llm_model = "gpt-3.5-custom"
    embedder_model = "text-embedding-3-custom"
    custom_config = {
        "llm": {"provider": "openai", "config": {"model": llm_model}},
        "embedder": {"provider": "openai", "config": {"model": embedder_model}},
    }

    tool = MyTool(config=custom_config)

    # No adapter was passed in, so the tool is expected to have built one.
    assert tool.adapter is not None
    assert isinstance(tool.adapter, EmbedchainAdapter)

    embedchain_app = cast(EmbedchainAdapter, tool.adapter).embedchain_app
    assert embedchain_app.llm.config.model == llm_model
    assert embedchain_app.embedding_model.config.model == embedder_model