reduce import time by 6x (#1396)

* reduce import time by 6x

* fix linting
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-10-06 16:55:32 -04:00
committed by GitHub
parent 0dfe3bcb0a
commit 5d8f8cbc79
7 changed files with 43 additions and 29 deletions

View File

@@ -5,11 +5,6 @@ import os
import shutil
from typing import Any, Dict, List, Optional
from embedchain import App
from embedchain.llm.base import BaseLlm
from embedchain.models.data_type import DataType
from embedchain.vectordb.chroma import InvalidDimensionException
from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path
@@ -29,10 +24,6 @@ def suppress_logging(
logger.setLevel(original_level)
class FakeLLM(BaseLlm):
    # No-op stand-in assigned to `app.llm` (see RAGStorage) so the embedchain
    # App can be constructed without configuring a real LLM backend.
    pass
class RAGStorage(Storage):
"""
Extends Storage to handle embeddings for memory entries, improving
@@ -74,9 +65,19 @@ class RAGStorage(Storage):
if embedder_config:
config["embedder"] = embedder_config
self.type = type
self.app = App.from_config(config=config)
self.config = config
self.allow_reset = allow_reset
def _initialize_app(self):
from embedchain import App
from embedchain.llm.base import BaseLlm
class FakeLLM(BaseLlm):
pass
self.app = App.from_config(config=self.config)
self.app.llm = FakeLLM()
if allow_reset:
if self.allow_reset:
self.app.reset()
def _sanitize_role(self, role: str) -> str:
@@ -86,6 +87,8 @@ class RAGStorage(Storage):
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app"):
self._initialize_app()
self._generate_embedding(value, metadata)
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
@@ -95,6 +98,10 @@ class RAGStorage(Storage):
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.vectordb.chroma import InvalidDimensionException
with suppress_logging():
try:
results = (
@@ -108,6 +115,10 @@ class RAGStorage(Storage):
return [r for r in results if r["metadata"]["score"] >= score_threshold]
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.models.data_type import DataType
self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
def reset(self) -> None: