From 22ce67a8e799bb63405d68f7cc381eccd190608f Mon Sep 17 00:00:00 2001 From: Vidit Ostwal <110953813+Vidit-Ostwal@users.noreply.github.com> Date: Fri, 7 Feb 2025 23:19:25 +0530 Subject: [PATCH] Added reset memories function inside crew class (#2047) * Added reset memories function inside crew class * Fixed typos * Refractored the code * Refactor memory reset functionality in Crew class - Improved error handling and logging for memory reset operations - Added private methods to modularize memory reset logic - Enhanced type hints and docstrings - Updated CLI reset memories command to use new Crew method - Added utility function to get crew instance in CLI utils * fix linting issues * knowledge: Add null check in reset method for storage * cli: Update memory reset tests to use Crew's reset_memories method * cli: Enhance memory reset command with improved error handling and validation --------- Co-authored-by: Lorenze Jay Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com> --- src/crewai/cli/reset_memories_command.py | 48 +++++++------ src/crewai/cli/utils.py | 62 +++++++++++++++++ src/crewai/crew.py | 78 +++++++++++++++++++++ src/crewai/knowledge/knowledge.py | 6 ++ tests/cli/cli_test.py | 89 +++++++++++++----------- 5 files changed, 223 insertions(+), 60 deletions(-) diff --git a/src/crewai/cli/reset_memories_command.py b/src/crewai/cli/reset_memories_command.py index 554232f52..4f7f1beb6 100644 --- a/src/crewai/cli/reset_memories_command.py +++ b/src/crewai/cli/reset_memories_command.py @@ -2,6 +2,7 @@ import subprocess import click +from crewai.cli.utils import get_crew from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory @@ -30,30 +31,35 @@ def reset_memories_command( """ try: + crew = get_crew() + if not crew: + raise ValueError("No crew found.") if all: - ShortTermMemory().reset() - EntityMemory().reset() - LongTermMemory().reset() - TaskOutputStorageHandler().reset() - KnowledgeStorage().reset() + crew.reset_memories(command_type="all") click.echo("All memories have been reset.") - else: - if long: - LongTermMemory().reset() - click.echo("Long term memory has been reset.") + return - if short: - ShortTermMemory().reset() - click.echo("Short term memory has been reset.") - if entity: - EntityMemory().reset() - click.echo("Entity memory has been reset.") - if kickoff_outputs: - TaskOutputStorageHandler().reset() - click.echo("Latest Kickoff outputs stored has been reset.") - if knowledge: - KnowledgeStorage().reset() - click.echo("Knowledge has been reset.") + if not any([long, short, entity, kickoff_outputs, knowledge]): + click.echo( + "No memory type specified. Please specify at least one type to reset." + ) + return + + if long: + crew.reset_memories(command_type="long") + click.echo("Long term memory has been reset.") + if short: + crew.reset_memories(command_type="short") + click.echo("Short term memory has been reset.") + if entity: + crew.reset_memories(command_type="entity") + click.echo("Entity memory has been reset.") + if kickoff_outputs: + crew.reset_memories(command_type="kickoff_outputs") + click.echo("Latest Kickoff outputs stored has been reset.") + if knowledge: + crew.reset_memories(command_type="knowledge") + click.echo("Knowledge has been reset.") except subprocess.CalledProcessError as e: click.echo(f"An error occurred while resetting the memories: {e}", err=True) diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index a385e1f37..60eb2488a 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -9,6 +9,7 @@ import tomli from rich.console import Console from crewai.cli.constants import ENV_VARS +from crewai.crew import Crew if sys.version_info >= (3, 11): import tomllib @@ -247,3 +248,64 @@ def write_env_file(folder_path, env_vars): with open(env_file_path, "w") as file: for key, value in env_vars.items(): file.write(f"{key}={value}\n") + + +def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: + """Get the crew instance from the crew.py file.""" + try: + import importlib.util + import os + + for root, _, files in os.walk("."): + if "crew.py" in files: + crew_path = os.path.join(root, "crew.py") + try: + spec = importlib.util.spec_from_file_location( + "crew_module", crew_path + ) + if not spec or not spec.loader: + continue + module = importlib.util.module_from_spec(spec) + try: + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + for attr_name in dir(module): + attr = getattr(module, attr_name) + try: + if callable(attr) and hasattr(attr, "crew"): + crew_instance = attr().crew() + return crew_instance + + except Exception as e: + print(f"Error processing attribute {attr_name}: {e}") + continue + + except Exception as exec_error: + print(f"Error executing module: {exec_error}") + import traceback + + print(f"Traceback: {traceback.format_exc()}") + + except (ImportError, AttributeError) as e: + if require: + console.print( + f"Error importing crew from {crew_path}: {str(e)}", + style="bold red", + ) + continue + + break + + if require: + console.print("No valid Crew instance found in crew.py", style="bold red") + raise SystemExit + return None + + except Exception as e: + if require: + console.print( + f"Unexpected error while loading crew: {str(e)}", style="bold red" + ) + raise SystemExit + return None diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 6ec5520a0..6b500e097 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1,6 +1,7 @@ import asyncio import json import re +import sys import uuid import warnings from concurrent.futures import Future @@ -1147,3 +1148,80 @@ class Crew(BaseModel): def __repr__(self): return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" + + def reset_memories(self, command_type: str) -> None: + """Reset specific or all memories for the crew. + + Args: + command_type: Type of memory to reset. + Valid options: 'long', 'short', 'entity', 'knowledge', + 'kickoff_outputs', or 'all' + + Raises: + ValueError: If an invalid command type is provided. + RuntimeError: If memory reset operation fails. + """ + VALID_TYPES = frozenset( + ["long", "short", "entity", "knowledge", "kickoff_outputs", "all"] + ) + + if command_type not in VALID_TYPES: + raise ValueError( + f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}" + ) + + try: + if command_type == "all": + self._reset_all_memories() + else: + self._reset_specific_memory(command_type) + + self._logger.log("info", f"{command_type} memory has been reset") + + except Exception as e: + error_msg = f"Failed to reset {command_type} memory: {str(e)}" + self._logger.log("error", error_msg) + raise RuntimeError(error_msg) from e + + def _reset_all_memories(self) -> None: + """Reset all available memory systems.""" + memory_systems = [ + ("short term", self._short_term_memory), + ("entity", self._entity_memory), + ("long term", self._long_term_memory), + ("task output", self._task_output_handler), + ("knowledge", self.knowledge), + ] + + for name, system in memory_systems: + if system is not None: + try: + system.reset() + except Exception as e: + raise RuntimeError(f"Failed to reset {name} memory") from e + + def _reset_specific_memory(self, memory_type: str) -> None: + """Reset a specific memory system. + + Args: + memory_type: Type of memory to reset + + Raises: + RuntimeError: If the specified memory system fails to reset + """ + reset_functions = { + "long": (self._long_term_memory, "long term"), + "short": (self._short_term_memory, "short term"), + "entity": (self._entity_memory, "entity"), + "knowledge": (self.knowledge, "knowledge"), + "kickoff_outputs": (self._task_output_handler, "task output"), + } + + memory_system, name = reset_functions[memory_type] + if memory_system is None: + raise RuntimeError(f"{name} memory system is not initialized") + + try: + memory_system.reset() + except Exception as e: + raise RuntimeError(f"Failed to reset {name} memory") from e diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index d1d4ede6c..da1db90a8 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -67,3 +67,9 @@ class Knowledge(BaseModel): source.add() except Exception as e: raise e + + def reset(self) -> None: + if self.storage: + self.storage.reset() + else: + raise ValueError("Storage is not initialized.") diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index 15ed81637..dc0c502b7 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -55,72 +55,83 @@ def test_train_invalid_string_iterations(train_crew, runner): ) -@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") -@mock.patch("crewai.cli.reset_memories_command.EntityMemory") -@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") -@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler") -def test_reset_all_memories( - MockTaskOutputStorageHandler, - MockLongTermMemory, - MockEntityMemory, - MockShortTermMemory, - runner, -): - result = runner.invoke(reset_memories, ["--all"]) - MockShortTermMemory().reset.assert_called_once() - MockEntityMemory().reset.assert_called_once() - MockLongTermMemory().reset.assert_called_once() - MockTaskOutputStorageHandler().reset.assert_called_once() +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_all_memories(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew + result = runner.invoke(reset_memories, ["-a"]) + mock_crew.reset_memories.assert_called_once_with(command_type="all") assert result.output == "All memories have been reset.\n" -@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") -def test_reset_short_term_memories(MockShortTermMemory, runner): +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_short_term_memories(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew result = runner.invoke(reset_memories, ["-s"]) - MockShortTermMemory().reset.assert_called_once() + + mock_crew.reset_memories.assert_called_once_with(command_type="short") assert result.output == "Short term memory has been reset.\n" -@mock.patch("crewai.cli.reset_memories_command.EntityMemory") -def test_reset_entity_memories(MockEntityMemory, runner): +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_entity_memories(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew result = runner.invoke(reset_memories, ["-e"]) - MockEntityMemory().reset.assert_called_once() + + mock_crew.reset_memories.assert_called_once_with(command_type="entity") assert result.output == "Entity memory has been reset.\n" -@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") -def test_reset_long_term_memories(MockLongTermMemory, runner): +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_long_term_memories(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew result = runner.invoke(reset_memories, ["-l"]) - MockLongTermMemory().reset.assert_called_once() + + mock_crew.reset_memories.assert_called_once_with(command_type="long") assert result.output == "Long term memory has been reset.\n" -@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler") -def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner): +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_kickoff_outputs(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew result = runner.invoke(reset_memories, ["-k"]) - MockTaskOutputStorageHandler().reset.assert_called_once() + + mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs") assert result.output == "Latest Kickoff outputs stored has been reset.\n" -@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory") -@mock.patch("crewai.cli.reset_memories_command.LongTermMemory") -def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner): - result = runner.invoke( - reset_memories, - [ - "-s", - "-l", - ], +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_multiple_memory_flags(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew + result = runner.invoke(reset_memories, ["-s", "-l"]) + + # Check that reset_memories was called twice with the correct arguments + assert mock_crew.reset_memories.call_count == 2 + mock_crew.reset_memories.assert_has_calls( + [mock.call(command_type="long"), mock.call(command_type="short")] ) - MockShortTermMemory().reset.assert_called_once() - MockLongTermMemory().reset.assert_called_once() assert ( result.output == "Long term memory has been reset.\nShort term memory has been reset.\n" ) +@mock.patch("crewai.cli.reset_memories_command.get_crew") +def test_reset_knowledge(mock_get_crew, runner): + mock_crew = mock.Mock() + mock_get_crew.return_value = mock_crew + result = runner.invoke(reset_memories, ["--knowledge"]) + + mock_crew.reset_memories.assert_called_once_with(command_type="knowledge") + assert result.output == "Knowledge has been reset.\n" + + def test_reset_no_memory_flags(runner): result = runner.invoke( reset_memories,