mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Added reset memories function inside crew class (#2047)
* Added reset memories function inside crew class * Fixed typos * Refractored the code * Refactor memory reset functionality in Crew class - Improved error handling and logging for memory reset operations - Added private methods to modularize memory reset logic - Enhanced type hints and docstrings - Updated CLI reset memories command to use new Crew method - Added utility function to get crew instance in CLI utils * fix linting issues * knowledge: Add null check in reset method for storage * cli: Update memory reset tests to use Crew's reset_memories method * cli: Enhance memory reset command with improved error handling and validation --------- Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com> Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import subprocess
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from crewai.cli.utils import get_crew
|
||||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
from crewai.memory.entity.entity_memory import EntityMemory
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
@@ -30,29 +31,34 @@ def reset_memories_command(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
crew = get_crew()
|
||||||
|
if not crew:
|
||||||
|
raise ValueError("No crew found.")
|
||||||
if all:
|
if all:
|
||||||
ShortTermMemory().reset()
|
crew.reset_memories(command_type="all")
|
||||||
EntityMemory().reset()
|
|
||||||
LongTermMemory().reset()
|
|
||||||
TaskOutputStorageHandler().reset()
|
|
||||||
KnowledgeStorage().reset()
|
|
||||||
click.echo("All memories have been reset.")
|
click.echo("All memories have been reset.")
|
||||||
else:
|
return
|
||||||
if long:
|
|
||||||
LongTermMemory().reset()
|
|
||||||
click.echo("Long term memory has been reset.")
|
|
||||||
|
|
||||||
|
if not any([long, short, entity, kickoff_outputs, knowledge]):
|
||||||
|
click.echo(
|
||||||
|
"No memory type specified. Please specify at least one type to reset."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if long:
|
||||||
|
crew.reset_memories(command_type="long")
|
||||||
|
click.echo("Long term memory has been reset.")
|
||||||
if short:
|
if short:
|
||||||
ShortTermMemory().reset()
|
crew.reset_memories(command_type="short")
|
||||||
click.echo("Short term memory has been reset.")
|
click.echo("Short term memory has been reset.")
|
||||||
if entity:
|
if entity:
|
||||||
EntityMemory().reset()
|
crew.reset_memories(command_type="entity")
|
||||||
click.echo("Entity memory has been reset.")
|
click.echo("Entity memory has been reset.")
|
||||||
if kickoff_outputs:
|
if kickoff_outputs:
|
||||||
TaskOutputStorageHandler().reset()
|
crew.reset_memories(command_type="kickoff_outputs")
|
||||||
click.echo("Latest Kickoff outputs stored has been reset.")
|
click.echo("Latest Kickoff outputs stored has been reset.")
|
||||||
if knowledge:
|
if knowledge:
|
||||||
KnowledgeStorage().reset()
|
crew.reset_memories(command_type="knowledge")
|
||||||
click.echo("Knowledge has been reset.")
|
click.echo("Knowledge has been reset.")
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import tomli
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from crewai.cli.constants import ENV_VARS
|
from crewai.cli.constants import ENV_VARS
|
||||||
|
from crewai.crew import Crew
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
import tomllib
|
import tomllib
|
||||||
@@ -247,3 +248,64 @@ def write_env_file(folder_path, env_vars):
|
|||||||
with open(env_file_path, "w") as file:
|
with open(env_file_path, "w") as file:
|
||||||
for key, value in env_vars.items():
|
for key, value in env_vars.items():
|
||||||
file.write(f"{key}={value}\n")
|
file.write(f"{key}={value}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||||
|
"""Get the crew instance from the crew.py file."""
|
||||||
|
try:
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
|
||||||
|
for root, _, files in os.walk("."):
|
||||||
|
if "crew.py" in files:
|
||||||
|
crew_path = os.path.join(root, "crew.py")
|
||||||
|
try:
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"crew_module", crew_path
|
||||||
|
)
|
||||||
|
if not spec or not spec.loader:
|
||||||
|
continue
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
try:
|
||||||
|
sys.modules[spec.name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
for attr_name in dir(module):
|
||||||
|
attr = getattr(module, attr_name)
|
||||||
|
try:
|
||||||
|
if callable(attr) and hasattr(attr, "crew"):
|
||||||
|
crew_instance = attr().crew()
|
||||||
|
return crew_instance
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing attribute {attr_name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as exec_error:
|
||||||
|
print(f"Error executing module: {exec_error}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
print(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
if require:
|
||||||
|
console.print(
|
||||||
|
f"Error importing crew from {crew_path}: {str(e)}",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
if require:
|
||||||
|
console.print("No valid Crew instance found in crew.py", style="bold red")
|
||||||
|
raise SystemExit
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if require:
|
||||||
|
console.print(
|
||||||
|
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
||||||
|
)
|
||||||
|
raise SystemExit
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
@@ -1147,3 +1148,80 @@ class Crew(BaseModel):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||||
|
|
||||||
|
def reset_memories(self, command_type: str) -> None:
|
||||||
|
"""Reset specific or all memories for the crew.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_type: Type of memory to reset.
|
||||||
|
Valid options: 'long', 'short', 'entity', 'knowledge',
|
||||||
|
'kickoff_outputs', or 'all'
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an invalid command type is provided.
|
||||||
|
RuntimeError: If memory reset operation fails.
|
||||||
|
"""
|
||||||
|
VALID_TYPES = frozenset(
|
||||||
|
["long", "short", "entity", "knowledge", "kickoff_outputs", "all"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if command_type not in VALID_TYPES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if command_type == "all":
|
||||||
|
self._reset_all_memories()
|
||||||
|
else:
|
||||||
|
self._reset_specific_memory(command_type)
|
||||||
|
|
||||||
|
self._logger.log("info", f"{command_type} memory has been reset")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||||
|
self._logger.log("error", error_msg)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def _reset_all_memories(self) -> None:
|
||||||
|
"""Reset all available memory systems."""
|
||||||
|
memory_systems = [
|
||||||
|
("short term", self._short_term_memory),
|
||||||
|
("entity", self._entity_memory),
|
||||||
|
("long term", self._long_term_memory),
|
||||||
|
("task output", self._task_output_handler),
|
||||||
|
("knowledge", self.knowledge),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, system in memory_systems:
|
||||||
|
if system is not None:
|
||||||
|
try:
|
||||||
|
system.reset()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||||
|
|
||||||
|
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||||
|
"""Reset a specific memory system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory to reset
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the specified memory system fails to reset
|
||||||
|
"""
|
||||||
|
reset_functions = {
|
||||||
|
"long": (self._long_term_memory, "long term"),
|
||||||
|
"short": (self._short_term_memory, "short term"),
|
||||||
|
"entity": (self._entity_memory, "entity"),
|
||||||
|
"knowledge": (self.knowledge, "knowledge"),
|
||||||
|
"kickoff_outputs": (self._task_output_handler, "task output"),
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_system, name = reset_functions[memory_type]
|
||||||
|
if memory_system is None:
|
||||||
|
raise RuntimeError(f"{name} memory system is not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
memory_system.reset()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||||
|
|||||||
@@ -67,3 +67,9 @@ class Knowledge(BaseModel):
|
|||||||
source.add()
|
source.add()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
if self.storage:
|
||||||
|
self.storage.reset()
|
||||||
|
else:
|
||||||
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|||||||
@@ -55,72 +55,83 @@ def test_train_invalid_string_iterations(train_crew, runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
def test_reset_all_memories(mock_get_crew, runner):
|
||||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
mock_crew = mock.Mock()
|
||||||
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
mock_get_crew.return_value = mock_crew
|
||||||
def test_reset_all_memories(
|
result = runner.invoke(reset_memories, ["-a"])
|
||||||
MockTaskOutputStorageHandler,
|
|
||||||
MockLongTermMemory,
|
|
||||||
MockEntityMemory,
|
|
||||||
MockShortTermMemory,
|
|
||||||
runner,
|
|
||||||
):
|
|
||||||
result = runner.invoke(reset_memories, ["--all"])
|
|
||||||
MockShortTermMemory().reset.assert_called_once()
|
|
||||||
MockEntityMemory().reset.assert_called_once()
|
|
||||||
MockLongTermMemory().reset.assert_called_once()
|
|
||||||
MockTaskOutputStorageHandler().reset.assert_called_once()
|
|
||||||
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="all")
|
||||||
assert result.output == "All memories have been reset.\n"
|
assert result.output == "All memories have been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
def test_reset_short_term_memories(MockShortTermMemory, runner):
|
def test_reset_short_term_memories(mock_get_crew, runner):
|
||||||
|
mock_crew = mock.Mock()
|
||||||
|
mock_get_crew.return_value = mock_crew
|
||||||
result = runner.invoke(reset_memories, ["-s"])
|
result = runner.invoke(reset_memories, ["-s"])
|
||||||
MockShortTermMemory().reset.assert_called_once()
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="short")
|
||||||
assert result.output == "Short term memory has been reset.\n"
|
assert result.output == "Short term memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
def test_reset_entity_memories(MockEntityMemory, runner):
|
def test_reset_entity_memories(mock_get_crew, runner):
|
||||||
|
mock_crew = mock.Mock()
|
||||||
|
mock_get_crew.return_value = mock_crew
|
||||||
result = runner.invoke(reset_memories, ["-e"])
|
result = runner.invoke(reset_memories, ["-e"])
|
||||||
MockEntityMemory().reset.assert_called_once()
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||||
assert result.output == "Entity memory has been reset.\n"
|
assert result.output == "Entity memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
def test_reset_long_term_memories(MockLongTermMemory, runner):
|
def test_reset_long_term_memories(mock_get_crew, runner):
|
||||||
|
mock_crew = mock.Mock()
|
||||||
|
mock_get_crew.return_value = mock_crew
|
||||||
result = runner.invoke(reset_memories, ["-l"])
|
result = runner.invoke(reset_memories, ["-l"])
|
||||||
MockLongTermMemory().reset.assert_called_once()
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="long")
|
||||||
assert result.output == "Long term memory has been reset.\n"
|
assert result.output == "Long term memory has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner):
|
def test_reset_kickoff_outputs(mock_get_crew, runner):
|
||||||
|
mock_crew = mock.Mock()
|
||||||
|
mock_get_crew.return_value = mock_crew
|
||||||
result = runner.invoke(reset_memories, ["-k"])
|
result = runner.invoke(reset_memories, ["-k"])
|
||||||
MockTaskOutputStorageHandler().reset.assert_called_once()
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||||
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
def test_reset_multiple_memory_flags(mock_get_crew, runner):
|
||||||
def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner):
|
mock_crew = mock.Mock()
|
||||||
result = runner.invoke(
|
mock_get_crew.return_value = mock_crew
|
||||||
reset_memories,
|
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||||
[
|
|
||||||
"-s",
|
# Check that reset_memories was called twice with the correct arguments
|
||||||
"-l",
|
assert mock_crew.reset_memories.call_count == 2
|
||||||
],
|
mock_crew.reset_memories.assert_has_calls(
|
||||||
|
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||||
)
|
)
|
||||||
MockShortTermMemory().reset.assert_called_once()
|
|
||||||
MockLongTermMemory().reset.assert_called_once()
|
|
||||||
assert (
|
assert (
|
||||||
result.output
|
result.output
|
||||||
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||||
|
def test_reset_knowledge(mock_get_crew, runner):
|
||||||
|
mock_crew = mock.Mock()
|
||||||
|
mock_get_crew.return_value = mock_crew
|
||||||
|
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||||
|
|
||||||
|
mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||||
|
assert result.output == "Knowledge has been reset.\n"
|
||||||
|
|
||||||
|
|
||||||
def test_reset_no_memory_flags(runner):
|
def test_reset_no_memory_flags(runner):
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
reset_memories,
|
reset_memories,
|
||||||
|
|||||||
Reference in New Issue
Block a user