Reset memory (#958)

* reseting memory on cli

* using storage.reset

* deleting memories on command

* added tests

* handle when no flags are used

* added docs
This commit is contained in:
Lorenze Jay
2024-07-18 09:29:42 -07:00
committed by GitHub
parent 3f2f832e8d
commit 405d45c3fb
10 changed files with 231 additions and 4 deletions

View File

@@ -23,3 +23,9 @@ class EntityMemory(Memory):
"""Saves an entity item into the SQLite storage."""
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")

View File

@@ -30,3 +30,6 @@ class LongTermMemory(Memory):
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:
self.storage.reset()

View File

@@ -18,8 +18,16 @@ class ShortTermMemory(Memory):
)
super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: ShortTermMemoryItem) -> None:
super().save(item.data, item.metadata, item.agent)
def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)

View File

@@ -9,3 +9,6 @@ class Storage:
def search(self, key: str) -> Dict[str, Any]: # type: ignore
pass
def reset(self) -> None:
pass

View File

@@ -103,3 +103,20 @@ class LTMSQLiteStorage:
color="red",
)
return None
def reset(
self,
) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM long_term_memories")
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None

View File

@@ -2,6 +2,7 @@ import contextlib
import io
import logging
import os
import shutil
from typing import Any, Dict, List, Optional
from embedchain import App
@@ -71,13 +72,13 @@ class RAGStorage(Storage):
if embedder_config:
config["embedder"] = embedder_config
self.type = type
self.app = App.from_config(config=config)
self.app.llm = FakeLLM()
if allow_reset:
self.app.reset()
def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self._generate_embedding(value, metadata)
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
@@ -102,3 +103,11 @@ class RAGStorage(Storage):
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
with suppress_logging():
self.app.add(text, data_type="text", metadata=metadata)
def reset(self) -> None:
try:
shutil.rmtree(f"{db_storage_path()}/{self.type}")
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)