mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 16:18:13 +00:00
Reset memory (#958)
* reseting memory on cli * using storage.reset * deleting memories on command * added tests * handle when no flags are used * added docs
This commit is contained in:
@@ -161,6 +161,39 @@ my_crew = Crew(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Resetting Memory
|
||||||
|
```sh
|
||||||
|
crewai reset_memories [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Resetting Memory Options
|
||||||
|
- **`-l, --long`**
|
||||||
|
- **Description:** Reset LONG TERM memory.
|
||||||
|
- **Type:** Flag (boolean)
|
||||||
|
- **Default:** False
|
||||||
|
|
||||||
|
- **`-s, --short`**
|
||||||
|
- **Description:** Reset SHORT TERM memory.
|
||||||
|
- **Type:** Flag (boolean)
|
||||||
|
- **Default:** False
|
||||||
|
|
||||||
|
- **`-e, --entities`**
|
||||||
|
- **Description:** Reset ENTITIES memory.
|
||||||
|
- **Type:** Flag (boolean)
|
||||||
|
- **Default:** False
|
||||||
|
|
||||||
|
- **`-k, --kickoff-outputs`**
|
||||||
|
- **Description:** Reset LATEST KICKOFF TASK OUTPUTS.
|
||||||
|
- **Type:** Flag (boolean)
|
||||||
|
- **Default:** False
|
||||||
|
|
||||||
|
- **`-a, --all`**
|
||||||
|
- **Description:** Reset ALL memories.
|
||||||
|
- **Type:** Flag (boolean)
|
||||||
|
- **Default:** False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Benefits of Using crewAI's Memory System
|
## Benefits of Using crewAI's Memory System
|
||||||
- **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks.
|
- **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks.
|
||||||
- **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences.
|
- **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences.
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
|||||||
from .create_crew import create_crew
|
from .create_crew import create_crew
|
||||||
from .train_crew import train_crew
|
from .train_crew import train_crew
|
||||||
from .replay_from_task import replay_task_command
|
from .replay_from_task import replay_task_command
|
||||||
|
from .reset_memories_command import reset_memories_command
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@@ -99,5 +100,31 @@ def log_tasks_outputs() -> None:
|
|||||||
click.echo(f"An error occurred while logging task outputs: {e}", err=True)
|
click.echo(f"An error occurred while logging task outputs: {e}", err=True)
|
||||||
|
|
||||||
|
|
||||||
|
@crewai.command()
|
||||||
|
@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory")
|
||||||
|
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
|
||||||
|
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
|
||||||
|
@click.option(
|
||||||
|
"-k",
|
||||||
|
"--kickoff-outputs",
|
||||||
|
is_flag=True,
|
||||||
|
help="Reset LATEST KICKOFF TASK OUTPUTS",
|
||||||
|
)
|
||||||
|
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
|
||||||
|
def reset_memories(long, short, entities, kickoff_outputs, all):
|
||||||
|
"""
|
||||||
|
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not all and not (long or short or entities or kickoff_outputs):
|
||||||
|
click.echo(
|
||||||
|
"Please specify at least one memory type to reset using the appropriate flags."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
reset_memories_command(long, short, entities, kickoff_outputs, all)
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
crewai()
|
crewai()
|
||||||
|
|||||||
45
src/crewai/cli/reset_memories_command.py
Normal file
45
src/crewai/cli/reset_memories_command.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import subprocess
|
||||||
|
import click
|
||||||
|
|
||||||
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||||
|
|
||||||
|
|
||||||
|
def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
|
||||||
|
"""
|
||||||
|
Replay the crew execution from a specific task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task to replay from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if all:
|
||||||
|
ShortTermMemory().reset()
|
||||||
|
EntityMemory().reset()
|
||||||
|
LongTermMemory().reset()
|
||||||
|
TaskOutputStorageHandler().reset()
|
||||||
|
click.echo("All memories have been reset.")
|
||||||
|
else:
|
||||||
|
if long:
|
||||||
|
LongTermMemory().reset()
|
||||||
|
click.echo("Long term memory has been reset.")
|
||||||
|
|
||||||
|
if short:
|
||||||
|
ShortTermMemory().reset()
|
||||||
|
click.echo("Short term memory has been reset.")
|
||||||
|
if entity:
|
||||||
|
EntityMemory().reset()
|
||||||
|
click.echo("Entity memory has been reset.")
|
||||||
|
if kickoff_outputs:
|
||||||
|
TaskOutputStorageHandler().reset()
|
||||||
|
click.echo("Latest Kickoff outputs stored has been reset.")
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||||
|
click.echo(e.output, err=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||||
@@ -23,3 +23,9 @@ class EntityMemory(Memory):
|
|||||||
"""Saves an entity item into the SQLite storage."""
|
"""Saves an entity item into the SQLite storage."""
|
||||||
data = f"{item.name}({item.type}): {item.description}"
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
super().save(data, item.metadata)
|
super().save(data, item.metadata)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
try:
|
||||||
|
self.storage.reset()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
||||||
|
|||||||
@@ -30,3 +30,6 @@ class LongTermMemory(Memory):
|
|||||||
|
|
||||||
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
|
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
|
||||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.storage.reset()
|
||||||
|
|||||||
@@ -18,8 +18,16 @@ class ShortTermMemory(Memory):
|
|||||||
)
|
)
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
def save(self, item: ShortTermMemoryItem) -> None:
|
||||||
super().save(item.data, item.metadata, item.agent)
|
super().save(item.data, item.metadata, item.agent)
|
||||||
|
|
||||||
def search(self, query: str, score_threshold: float = 0.35):
|
def search(self, query: str, score_threshold: float = 0.35):
|
||||||
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
try:
|
||||||
|
self.storage.reset()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred while resetting the short-term memory: {e}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,3 +9,6 @@ class Storage:
|
|||||||
|
|
||||||
def search(self, key: str) -> Dict[str, Any]: # type: ignore
|
def search(self, key: str) -> Dict[str, Any]: # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -103,3 +103,20 @@ class LTMSQLiteStorage:
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Resets the LTM table with error handling."""
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM long_term_memories")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import contextlib
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
@@ -71,13 +72,13 @@ class RAGStorage(Storage):
|
|||||||
|
|
||||||
if embedder_config:
|
if embedder_config:
|
||||||
config["embedder"] = embedder_config
|
config["embedder"] = embedder_config
|
||||||
|
self.type = type
|
||||||
self.app = App.from_config(config=config)
|
self.app = App.from_config(config=config)
|
||||||
self.app.llm = FakeLLM()
|
self.app.llm = FakeLLM()
|
||||||
if allow_reset:
|
if allow_reset:
|
||||||
self.app.reset()
|
self.app.reset()
|
||||||
|
|
||||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage"
|
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||||
self._generate_embedding(value, metadata)
|
self._generate_embedding(value, metadata)
|
||||||
|
|
||||||
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
|
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
|
||||||
@@ -102,3 +103,11 @@ class RAGStorage(Storage):
|
|||||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
|
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
self.app.add(text, data_type="text", metadata=metadata)
|
self.app.add(text, data_type="text", metadata=metadata)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
try:
|
||||||
|
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from unittest import mock
|
|||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from crewai.cli.cli import train, version
|
from crewai.cli.cli import train, version, reset_memories
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -41,6 +41,82 @@ def test_train_invalid_string_iterations(train_crew, runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
||||||
|
def test_reset_all_memories(
|
||||||
|
MockTaskOutputStorageHandler,
|
||||||
|
MockLongTermMemory,
|
||||||
|
MockEntityMemory,
|
||||||
|
MockShortTermMemory,
|
||||||
|
runner,
|
||||||
|
):
|
||||||
|
result = runner.invoke(reset_memories, ["--all"])
|
||||||
|
MockShortTermMemory().reset.assert_called_once()
|
||||||
|
MockEntityMemory().reset.assert_called_once()
|
||||||
|
MockLongTermMemory().reset.assert_called_once()
|
||||||
|
MockTaskOutputStorageHandler().reset.assert_called_once()
|
||||||
|
|
||||||
|
assert result.output == "All memories have been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||||
|
def test_reset_short_term_memories(MockShortTermMemory, runner):
|
||||||
|
result = runner.invoke(reset_memories, ["-s"])
|
||||||
|
MockShortTermMemory().reset.assert_called_once()
|
||||||
|
assert result.output == "Short term memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
||||||
|
def test_reset_entity_memories(MockEntityMemory, runner):
|
||||||
|
result = runner.invoke(reset_memories, ["-e"])
|
||||||
|
MockEntityMemory().reset.assert_called_once()
|
||||||
|
assert result.output == "Entity memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||||
|
def test_reset_long_term_memories(MockLongTermMemory, runner):
|
||||||
|
result = runner.invoke(reset_memories, ["-l"])
|
||||||
|
MockLongTermMemory().reset.assert_called_once()
|
||||||
|
assert result.output == "Long term memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
||||||
|
def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner):
|
||||||
|
result = runner.invoke(reset_memories, ["-k"])
|
||||||
|
MockTaskOutputStorageHandler().reset.assert_called_once()
|
||||||
|
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||||
|
def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner):
|
||||||
|
result = runner.invoke(
|
||||||
|
reset_memories,
|
||||||
|
[
|
||||||
|
"-s",
|
||||||
|
"-l",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
MockShortTermMemory().reset.assert_called_once()
|
||||||
|
MockLongTermMemory().reset.assert_called_once()
|
||||||
|
assert (
|
||||||
|
result.output
|
||||||
|
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_no_memory_flags(runner):
|
||||||
|
result = runner.invoke(
|
||||||
|
reset_memories,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result.output
|
||||||
|
== "Please specify at least one memory type to reset using the appropriate flags.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_version_command(runner):
|
def test_version_command(runner):
|
||||||
result = runner.invoke(version)
|
result = runner.invoke(version)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user