Reset memory (#958)

* reseting memory on cli

* using storage.reset

* deleting memories on command

* added tests

* handle when no flags are used

* added docs
This commit is contained in:
Lorenze Jay
2024-07-18 09:29:42 -07:00
committed by GitHub
parent 3f2f832e8d
commit 405d45c3fb
10 changed files with 231 additions and 4 deletions

View File

@@ -161,6 +161,39 @@ my_crew = Crew(
) )
``` ```
### Resetting Memory
```sh
crewai reset_memories [OPTIONS]
```
#### Resetting Memory Options
- **`-l, --long`**
- **Description:** Reset LONG TERM memory.
- **Type:** Flag (boolean)
- **Default:** False
- **`-s, --short`**
- **Description:** Reset SHORT TERM memory.
- **Type:** Flag (boolean)
- **Default:** False
- **`-e, --entities`**
- **Description:** Reset ENTITIES memory.
- **Type:** Flag (boolean)
- **Default:** False
- **`-k, --kickoff-outputs`**
- **Description:** Reset LATEST KICKOFF TASK OUTPUTS.
- **Type:** Flag (boolean)
- **Default:** False
- **`-a, --all`**
- **Description:** Reset ALL memories.
- **Type:** Flag (boolean)
- **Default:** False
## Benefits of Using crewAI's Memory System ## Benefits of Using crewAI's Memory System
- **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks. - **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks.
- **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences. - **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences.

View File

@@ -9,6 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
from .create_crew import create_crew from .create_crew import create_crew
from .train_crew import train_crew from .train_crew import train_crew
from .replay_from_task import replay_task_command from .replay_from_task import replay_task_command
from .reset_memories_command import reset_memories_command
@click.group() @click.group()
@@ -99,5 +100,31 @@ def log_tasks_outputs() -> None:
click.echo(f"An error occurred while logging task outputs: {e}", err=True) click.echo(f"An error occurred while logging task outputs: {e}", err=True)
@crewai.command()
@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory")
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
@click.option(
"-k",
"--kickoff-outputs",
is_flag=True,
help="Reset LATEST KICKOFF TASK OUTPUTS",
)
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
def reset_memories(long, short, entities, kickoff_outputs, all):
"""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
try:
if not all and not (long or short or entities or kickoff_outputs):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(long, short, entities, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)
if __name__ == "__main__": if __name__ == "__main__":
crewai() crewai()

View File

@@ -0,0 +1,45 @@
import subprocess
import click
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
if all:
ShortTermMemory().reset()
EntityMemory().reset()
LongTermMemory().reset()
TaskOutputStorageHandler().reset()
click.echo("All memories have been reset.")
else:
if long:
LongTermMemory().reset()
click.echo("Long term memory has been reset.")
if short:
ShortTermMemory().reset()
click.echo("Short term memory has been reset.")
if entity:
EntityMemory().reset()
click.echo("Entity memory has been reset.")
if kickoff_outputs:
TaskOutputStorageHandler().reset()
click.echo("Latest Kickoff outputs stored has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
click.echo(e.output, err=True)
except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True)

View File

@@ -23,3 +23,9 @@ class EntityMemory(Memory):
"""Saves an entity item into the SQLite storage.""" """Saves an entity item into the SQLite storage."""
data = f"{item.name}({item.type}): {item.description}" data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata) super().save(data, item.metadata)
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")

View File

@@ -30,3 +30,6 @@ class LongTermMemory(Memory):
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]: def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:
self.storage.reset()

View File

@@ -18,8 +18,16 @@ class ShortTermMemory(Memory):
) )
super().__init__(storage) super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" def save(self, item: ShortTermMemoryItem) -> None:
super().save(item.data, item.metadata, item.agent) super().save(item.data, item.metadata, item.agent)
def search(self, query: str, score_threshold: float = 0.35): def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)

View File

@@ -9,3 +9,6 @@ class Storage:
def search(self, key: str) -> Dict[str, Any]: # type: ignore def search(self, key: str) -> Dict[str, Any]: # type: ignore
pass pass
def reset(self) -> None:
pass

View File

@@ -103,3 +103,20 @@ class LTMSQLiteStorage:
color="red", color="red",
) )
return None return None
def reset(
self,
) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM long_term_memories")
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None

View File

@@ -2,6 +2,7 @@ import contextlib
import io import io
import logging import logging
import os import os
import shutil
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from embedchain import App from embedchain import App
@@ -71,13 +72,13 @@ class RAGStorage(Storage):
if embedder_config: if embedder_config:
config["embedder"] = embedder_config config["embedder"] = embedder_config
self.type = type
self.app = App.from_config(config=config) self.app = App.from_config(config=config)
self.app.llm = FakeLLM() self.app.llm = FakeLLM()
if allow_reset: if allow_reset:
self.app.reset() self.app.reset()
def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage" def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self._generate_embedding(value, metadata) self._generate_embedding(value, metadata)
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage" def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
@@ -102,3 +103,11 @@ class RAGStorage(Storage):
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any: def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
with suppress_logging(): with suppress_logging():
self.app.add(text, data_type="text", metadata=metadata) self.app.add(text, data_type="text", metadata=metadata)
def reset(self) -> None:
try:
shutil.rmtree(f"{db_storage_path()}/{self.type}")
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)

View File

@@ -3,7 +3,7 @@ from unittest import mock
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner
from crewai.cli.cli import train, version from crewai.cli.cli import train, version, reset_memories
@pytest.fixture @pytest.fixture
@@ -41,6 +41,82 @@ def test_train_invalid_string_iterations(train_crew, runner):
) )
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
def test_reset_all_memories(
MockTaskOutputStorageHandler,
MockLongTermMemory,
MockEntityMemory,
MockShortTermMemory,
runner,
):
result = runner.invoke(reset_memories, ["--all"])
MockShortTermMemory().reset.assert_called_once()
MockEntityMemory().reset.assert_called_once()
MockLongTermMemory().reset.assert_called_once()
MockTaskOutputStorageHandler().reset.assert_called_once()
assert result.output == "All memories have been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
def test_reset_short_term_memories(MockShortTermMemory, runner):
result = runner.invoke(reset_memories, ["-s"])
MockShortTermMemory().reset.assert_called_once()
assert result.output == "Short term memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
def test_reset_entity_memories(MockEntityMemory, runner):
result = runner.invoke(reset_memories, ["-e"])
MockEntityMemory().reset.assert_called_once()
assert result.output == "Entity memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
def test_reset_long_term_memories(MockLongTermMemory, runner):
result = runner.invoke(reset_memories, ["-l"])
MockLongTermMemory().reset.assert_called_once()
assert result.output == "Long term memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner):
result = runner.invoke(reset_memories, ["-k"])
MockTaskOutputStorageHandler().reset.assert_called_once()
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner):
result = runner.invoke(
reset_memories,
[
"-s",
"-l",
],
)
MockShortTermMemory().reset.assert_called_once()
MockLongTermMemory().reset.assert_called_once()
assert (
result.output
== "Long term memory has been reset.\nShort term memory has been reset.\n"
)
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,
)
assert (
result.output
== "Please specify at least one memory type to reset using the appropriate flags.\n"
)
def test_version_command(runner): def test_version_command(runner):
result = runner.invoke(version) result = runner.invoke(version)