From be1b9a399408ef1a0edf3454282d31cf88489533 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:29:42 -0700 Subject: [PATCH] Reset memory (#958) * reseting memory on cli * using storage.reset * deleting memories on command * added tests * handle when no flags are used * added docs --- docs/core-concepts/Memory.md | 33 ++++++++ src/crewai/cli/cli.py | 27 +++++++ src/crewai/cli/reset_memories_command.py | 45 +++++++++++ src/crewai/memory/entity/entity_memory.py | 6 ++ .../memory/long_term/long_term_memory.py | 3 + .../memory/short_term/short_term_memory.py | 10 ++- src/crewai/memory/storage/interface.py | 3 + .../memory/storage/ltm_sqlite_storage.py | 17 ++++ src/crewai/memory/storage/rag_storage.py | 13 +++- tests/cli/cli_test.py | 78 ++++++++++++++++++- 10 files changed, 231 insertions(+), 4 deletions(-) create mode 100644 src/crewai/cli/reset_memories_command.py diff --git a/docs/core-concepts/Memory.md b/docs/core-concepts/Memory.md index 9c9d47be0..0d1d3a67d 100644 --- a/docs/core-concepts/Memory.md +++ b/docs/core-concepts/Memory.md @@ -161,6 +161,39 @@ my_crew = Crew( ) ``` +### Resetting Memory +```sh +crewai reset_memories [OPTIONS] +``` + +#### Resetting Memory Options +- **`-l, --long`** + - **Description:** Reset LONG TERM memory. + - **Type:** Flag (boolean) + - **Default:** False + +- **`-s, --short`** + - **Description:** Reset SHORT TERM memory. + - **Type:** Flag (boolean) + - **Default:** False + +- **`-e, --entities`** + - **Description:** Reset ENTITIES memory. + - **Type:** Flag (boolean) + - **Default:** False + +- **`-k, --kickoff-outputs`** + - **Description:** Reset LATEST KICKOFF TASK OUTPUTS. + - **Type:** Flag (boolean) + - **Default:** False + +- **`-a, --all`** + - **Description:** Reset ALL memories. + - **Type:** Flag (boolean) + - **Default:** False + + + ## Benefits of Using crewAI's Memory System - **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks. - **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences. diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 6b8038248..7b025a567 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -9,6 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( from .create_crew import create_crew from .train_crew import train_crew from .replay_from_task import replay_task_command +from .reset_memories_command import reset_memories_command @click.group() @@ -99,5 +100,31 @@ def log_tasks_outputs() -> None: click.echo(f"An error occurred while logging task outputs: {e}", err=True) +@crewai.command() +@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory") +@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory") +@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory") +@click.option( + "-k", + "--kickoff-outputs", + is_flag=True, + help="Reset LATEST KICKOFF TASK OUTPUTS", +) +@click.option("-a", "--all", is_flag=True, help="Reset ALL memories") +def reset_memories(long, short, entities, kickoff_outputs, all): + """ + Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved. + """ + try: + if not all and not (long or short or entities or kickoff_outputs): + click.echo( + "Please specify at least one memory type to reset using the appropriate flags." + ) + return + reset_memories_command(long, short, entities, kickoff_outputs, all) + except Exception as e: + click.echo(f"An error occurred while resetting memories: {e}", err=True) + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/reset_memories_command.py b/src/crewai/cli/reset_memories_command.py new file mode 100644 index 000000000..68d82a92a --- /dev/null +++ b/src/crewai/cli/reset_memories_command.py @@ -0,0 +1,45 @@ +import subprocess +import click + +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler + + +def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None: + """ + Replay the crew execution from a specific task. + + Args: + task_id (str): The ID of the task to replay from. + """ + + try: + if all: + ShortTermMemory().reset() + EntityMemory().reset() + LongTermMemory().reset() + TaskOutputStorageHandler().reset() + click.echo("All memories have been reset.") + else: + if long: + LongTermMemory().reset() + click.echo("Long term memory has been reset.") + + if short: + ShortTermMemory().reset() + click.echo("Short term memory has been reset.") + if entity: + EntityMemory().reset() + click.echo("Entity memory has been reset.") + if kickoff_outputs: + TaskOutputStorageHandler().reset() + click.echo("Latest Kickoff outputs stored has been reset.") + + except subprocess.CalledProcessError as e: + click.echo(f"An error occurred while resetting the memories: {e}", err=True) + click.echo(e.output, err=True) + + except Exception as e: + click.echo(f"An unexpected error occurred: {e}", err=True) diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 519d7a62a..50aaeeaab 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -23,3 +23,9 @@ class EntityMemory(Memory): """Saves an entity item into the SQLite storage.""" data = f"{item.name}({item.type}): {item.description}" super().save(data, item.metadata) + + def reset(self) -> None: + try: + self.storage.reset() + except Exception as e: + raise Exception(f"An error occurred while resetting the entity memory: {e}") diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index b437d9e7d..041268107 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -30,3 +30,6 @@ class LongTermMemory(Memory): def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]: return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" + + def reset(self) -> None: + self.storage.reset() diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index e9410ebbc..ec65ad6f6 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -18,8 +18,16 @@ class ShortTermMemory(Memory): ) super().__init__(storage) - def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" + def save(self, item: ShortTermMemoryItem) -> None: super().save(item.data, item.metadata, item.agent) def search(self, query: str, score_threshold: float = 0.35): return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters + + def reset(self) -> None: + try: + self.storage.reset() + except Exception as e: + raise Exception( + f"An error occurred while resetting the short-term memory: {e}" + ) diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index 75c5a7b7a..e988862ba 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -9,3 +9,6 @@ class Storage: def search(self, key: str) -> Dict[str, Any]: # type: ignore pass + + def reset(self) -> None: + pass diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 9825abad0..7fb388a62 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -103,3 +103,20 @@ class LTMSQLiteStorage: color="red", ) return None + + def reset( + self, + ) -> None: + """Resets the LTM table with error handling.""" + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM long_term_memories") + conn.commit() + + except sqlite3.Error as e: + self._printer.print( + content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", + color="red", + ) + return None diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 9a332ebfd..e53f096e9 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -2,6 +2,7 @@ import contextlib import io import logging import os +import shutil from typing import Any, Dict, List, Optional from embedchain import App @@ -71,13 +72,13 @@ class RAGStorage(Storage): if embedder_config: config["embedder"] = embedder_config - + self.type = type self.app = App.from_config(config=config) self.app.llm = FakeLLM() if allow_reset: self.app.reset() - def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage" + def save(self, value: Any, metadata: Dict[str, Any]) -> None: self._generate_embedding(value, metadata) def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage" @@ -102,3 +103,11 @@ class RAGStorage(Storage): def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any: with suppress_logging(): self.app.add(text, data_type="text", metadata=metadata) + + def reset(self) -> None: + try: + shutil.rmtree(f"{db_storage_path()}/{self.type}") + except Exception as e: + raise Exception( + f"An error occurred while resetting the {self.type} memory: {e}" + ) diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index f2c6879b0..f877f913a 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -3,7 +3,7 @@ from unittest import mock import pytest from click.testing import CliRunner -from crewai.cli.cli import train, version +from crewai.cli.cli import train, version, reset_memories @pytest.fixture @@ -41,6 +41,82 @@ def test_train_invalid_string_iterations(train_crew, runner): ) +@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") +@mock.patch("crewai.cli.reset_memories_command.EntityMemory") +@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") +@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler") +def test_reset_all_memories( + MockTaskOutputStorageHandler, + MockLongTermMemory, + MockEntityMemory, + MockShortTermMemory, + runner, +): + result = runner.invoke(reset_memories, ["--all"]) + MockShortTermMemory().reset.assert_called_once() + MockEntityMemory().reset.assert_called_once() + MockLongTermMemory().reset.assert_called_once() + MockTaskOutputStorageHandler().reset.assert_called_once() + + assert result.output == "All memories have been reset.\n" + + +@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") +def test_reset_short_term_memories(MockShortTermMemory, runner): + result = runner.invoke(reset_memories, ["-s"]) + MockShortTermMemory().reset.assert_called_once() + assert result.output == "Short term memory has been reset.\n" + + +@mock.patch("crewai.cli.reset_memories_command.EntityMemory") +def test_reset_entity_memories(MockEntityMemory, runner): + result = runner.invoke(reset_memories, ["-e"]) + MockEntityMemory().reset.assert_called_once() + assert result.output == "Entity memory has been reset.\n" + + +@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") +def test_reset_long_term_memories(MockLongTermMemory, runner): + result = runner.invoke(reset_memories, ["-l"]) + MockLongTermMemory().reset.assert_called_once() + assert result.output == "Long term memory has been reset.\n" + + +@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler") +def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner): + result = runner.invoke(reset_memories, ["-k"]) + MockTaskOutputStorageHandler().reset.assert_called_once() + assert result.output == "Latest Kickoff outputs stored has been reset.\n" + + +@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") +@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") +def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner): + result = runner.invoke( + reset_memories, + [ + "-s", + "-l", + ], + ) + MockShortTermMemory().reset.assert_called_once() + MockLongTermMemory().reset.assert_called_once() + assert ( + result.output + == "Long term memory has been reset.\nShort term memory has been reset.\n" + ) + + +def test_reset_no_memory_flags(runner): + result = runner.invoke( + reset_memories, + ) + assert ( + result.output + == "Please specify at least one memory type to reset using the appropriate flags.\n" + ) + + def test_version_command(runner): result = runner.invoke(version)