From f89c2bfb7e5479a07df32e0b8eb593c2f40cdd3e Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Fri, 2 May 2025 13:40:42 -0300 Subject: [PATCH] Fix crewai reset-memories when Embedding dimension mismatch (#2737) * fix: support to reset memories after changing Crew's embedder The sources must not be added while initializing the Knowledge otherwise we could not reset it * chore: improve reset memory feedback Previously, even when no memories were actually erased, we logged that they had been. From now on, the log will specify which memory has been reset. * feat: improve get_crew discovery from a single file Crew instances can now be discovered from any function or method with a return type annotation of -> Crew, as well as from module-level attributes assigned to a Crew instance. Additionally, crews can be retrieved from within a Flow * refactor: make add_sources a public method from Knowledge --- src/crewai/cli/reset_memories_command.py | 62 ++++++---- src/crewai/cli/utils.py | 52 ++++++-- src/crewai/crew.py | 23 +++- src/crewai/knowledge/knowledge.py | 3 +- tests/cli/cli_test.py | 151 +++++++++++++++-------- 5 files changed, 199 insertions(+), 92 deletions(-) diff --git a/src/crewai/cli/reset_memories_command.py b/src/crewai/cli/reset_memories_command.py index 4870d6424..eaf54ffb7 100644 --- a/src/crewai/cli/reset_memories_command.py +++ b/src/crewai/cli/reset_memories_command.py @@ -2,7 +2,7 @@ import subprocess import click -from crewai.cli.utils import get_crew +from crewai.cli.utils import get_crews def reset_memories_command( @@ -26,35 +26,47 @@ def reset_memories_command( """ try: - crew = get_crew() - if not crew: - raise ValueError("No crew found.") - if all: - crew.reset_memories(command_type="all") - click.echo("All memories have been reset.") - return - - if not any([long, short, entity, kickoff_outputs, knowledge]): + if not any([long, short, entity, kickoff_outputs, knowledge, all]): click.echo( "No memory type specified. Please specify at least one type to reset." ) return - if long: - crew.reset_memories(command_type="long") - click.echo("Long term memory has been reset.") - if short: - crew.reset_memories(command_type="short") - click.echo("Short term memory has been reset.") - if entity: - crew.reset_memories(command_type="entity") - click.echo("Entity memory has been reset.") - if kickoff_outputs: - crew.reset_memories(command_type="kickoff_outputs") - click.echo("Latest Kickoff outputs stored has been reset.") - if knowledge: - crew.reset_memories(command_type="knowledge") - click.echo("Knowledge has been reset.") + crews = get_crews() + if not crews: + raise ValueError("No crew found.") + for crew in crews: + if all: + crew.reset_memories(command_type="all") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed." + ) + continue + if long: + crew.reset_memories(command_type="long") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset." + ) + if short: + crew.reset_memories(command_type="short") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset." + ) + if entity: + crew.reset_memories(command_type="entity") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset." + ) + if kickoff_outputs: + crew.reset_memories(command_type="kickoff_outputs") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset." + ) + if knowledge: + crew.reset_memories(command_type="knowledge") + click.echo( + f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset." + ) except subprocess.CalledProcessError as e: click.echo(f"An error occurred while resetting the memories: {e}", err=True) diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 078afec60..74fc414d9 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -2,7 +2,8 @@ import os import shutil import sys from functools import reduce -from typing import Any, Dict, List +from inspect import isfunction, ismethod +from typing import Any, Dict, List, get_type_hints import click import tomli @@ -10,6 +11,7 @@ from rich.console import Console from crewai.cli.constants import ENV_VARS from crewai.crew import Crew +from crewai.flow import Flow if sys.version_info >= (3, 11): import tomllib @@ -250,11 +252,11 @@ def write_env_file(folder_path, env_vars): file.write(f"{key}={value}\n") -def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: - """Get the crew instance from the crew.py file.""" +def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: + """Get the crew instances from the a file.""" + crew_instances = [] try: import importlib.util - import os for root, _, files in os.walk("."): if crew_path in files: @@ -271,12 +273,10 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: spec.loader.exec_module(module) for attr_name in dir(module): - attr = getattr(module, attr_name) - try: - if callable(attr) and hasattr(attr, "crew"): - crew_instance = attr().crew() - return crew_instance + module_attr = getattr(module, attr_name) + try: + crew_instances.extend(fetch_crews(module_attr)) except Exception as e: print(f"Error processing attribute {attr_name}: {e}") continue @@ -286,7 +286,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: import traceback print(f"Traceback: {traceback.format_exc()}") - except (ImportError, AttributeError) as e: if require: console.print( @@ -300,7 +299,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: if require: console.print("No valid Crew instance found in crew.py", style="bold red") raise SystemExit - return None except Exception as e: if require: @@ -308,4 +306,36 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: f"Unexpected error while loading crew: {str(e)}", style="bold red" ) raise SystemExit + return crew_instances + + +def get_crew_instance(module_attr) -> Crew | None: + if ( + callable(module_attr) + and hasattr(module_attr, "is_crew_class") + and module_attr.is_crew_class + ): + return module_attr().crew() + if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints( + module_attr + ).get("return") is Crew: + return module_attr() + elif isinstance(module_attr, Crew): + return module_attr + else: return None + + +def fetch_crews(module_attr) -> list[Crew]: + crew_instances: list[Crew] = [] + + if crew_instance := get_crew_instance(module_attr): + crew_instances.append(crew_instance) + + if isinstance(module_attr, type) and issubclass(module_attr, Flow): + instance = module_attr() + for attr_name in dir(instance): + attr = getattr(instance, attr_name) + if crew_instance := get_crew_instance(attr): + crew_instances.append(crew_instance) + return crew_instances diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 7c9696f6d..75433554c 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -304,7 +304,9 @@ class Crew(BaseModel): """Initialize private memory attributes.""" self._external_memory = ( # External memory doesn’t support a default value since it was designed to be managed entirely externally - self.external_memory.set_crew(self) if self.external_memory else None + self.external_memory.set_crew(self) + if self.external_memory + else None ) self._long_term_memory = self.long_term_memory @@ -333,6 +335,7 @@ class Crew(BaseModel): embedder=self.embedder, collection_name="crew", ) + self.knowledge.add_sources() except Exception as e: self._logger.log( @@ -1369,8 +1372,6 @@ class Crew(BaseModel): else: self._reset_specific_memory(command_type) - self._logger.log("info", f"{command_type} memory has been reset") - except Exception as e: error_msg = f"Failed to reset {command_type} memory: {str(e)}" self._logger.log("error", error_msg) @@ -1391,8 +1392,14 @@ class Crew(BaseModel): if system is not None: try: system.reset() + self._logger.log( + "info", + f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", + ) except Exception as e: - raise RuntimeError(f"Failed to reset {name} memory") from e + raise RuntimeError( + f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + ) from e def _reset_specific_memory(self, memory_type: str) -> None: """Reset a specific memory system. @@ -1421,5 +1428,11 @@ class Crew(BaseModel): try: memory_system.reset() + self._logger.log( + "info", + f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", + ) except Exception as e: - raise RuntimeError(f"Failed to reset {name} memory") from e + raise RuntimeError( + f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + ) from e diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 824325d12..2340dec90 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -41,7 +41,6 @@ class Knowledge(BaseModel): ) self.sources = sources self.storage.initialize_knowledge_storage() - self._add_sources() def query( self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 @@ -63,7 +62,7 @@ class Knowledge(BaseModel): ) return results - def _add_sources(self): + def add_sources(self): try: for source in self.sources: source.storage = self.storage diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index dc0c502b7..19439cb82 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -18,6 +18,7 @@ from crewai.cli.cli import ( train, version, ) +from crewai.crew import Crew @pytest.fixture @@ -55,81 +56,133 @@ def test_train_invalid_string_iterations(train_crew, runner): ) -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_all_memories(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +@pytest.fixture +def mock_crew(): + _mock = mock.Mock(spec=Crew, name="test_crew") + _mock.name = "test_crew" + return _mock + + +@pytest.fixture +def mock_get_crews(mock_crew): + with mock.patch( + "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew] + ) as mock_get_crew: + yield mock_get_crew + + +def test_reset_all_memories(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-a"]) - mock_crew.reset_memories.assert_called_once_with(command_type="all") - assert result.output == "All memories have been reset.\n" + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="all") + assert ( + f"[Crew ({crew.name})] Reset memories command has been completed." + in result.output + ) + call_count += 1 + + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_short_term_memories(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_short_term_memories(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-s"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="short") + assert ( + f"[Crew ({crew.name})] Short term memory has been reset." in result.output + ) + call_count += 1 - mock_crew.reset_memories.assert_called_once_with(command_type="short") - assert result.output == "Short term memory has been reset.\n" + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_entity_memories(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_entity_memories(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-e"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="entity") + assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output + call_count += 1 - mock_crew.reset_memories.assert_called_once_with(command_type="entity") - assert result.output == "Entity memory has been reset.\n" + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_long_term_memories(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_long_term_memories(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-l"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="long") + assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output + call_count += 1 - mock_crew.reset_memories.assert_called_once_with(command_type="long") - assert result.output == "Long term memory has been reset.\n" + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_kickoff_outputs(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_kickoff_outputs(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-k"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs") + assert ( + f"[Crew ({crew.name})] Latest Kickoff outputs stored has been reset." + in result.output + ) + call_count += 1 - mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs") - assert result.output == "Latest Kickoff outputs stored has been reset.\n" + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_multiple_memory_flags(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_multiple_memory_flags(mock_get_crews, runner): result = runner.invoke(reset_memories, ["-s", "-l"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_has_calls( + [mock.call(command_type="long"), mock.call(command_type="short")] + ) + assert ( + f"[Crew ({crew.name})] Long term memory has been reset.\n" + f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output + ) + call_count += 1 - # Check that reset_memories was called twice with the correct arguments - assert mock_crew.reset_memories.call_count == 2 - mock_crew.reset_memories.assert_has_calls( - [mock.call(command_type="long"), mock.call(command_type="short")] - ) - assert ( - result.output - == "Long term memory has been reset.\nShort term memory has been reset.\n" - ) + assert call_count == 1, "reset_memories should have been called once" -@mock.patch("crewai.cli.reset_memories_command.get_crew") -def test_reset_knowledge(mock_get_crew, runner): - mock_crew = mock.Mock() - mock_get_crew.return_value = mock_crew +def test_reset_knowledge(mock_get_crews, runner): + result = runner.invoke(reset_memories, ["--knowledge"]) + call_count = 0 + for crew in mock_get_crews.return_value: + crew.reset_memories.assert_called_once_with(command_type="knowledge") + assert f"[Crew ({crew.name})] Knowledge has been reset." in result.output + call_count += 1 + + assert call_count == 1, "reset_memories should have been called once" + + +def test_reset_memory_from_many_crews(mock_get_crews, runner): + + crews = [] + for crew_id in ["id-1234", "id-5678"]: + mock_crew = mock.Mock(spec=Crew) + mock_crew.name = None + mock_crew.id = crew_id + crews.append(mock_crew) + + mock_get_crews.return_value = crews + + # Run the command result = runner.invoke(reset_memories, ["--knowledge"]) - mock_crew.reset_memories.assert_called_once_with(command_type="knowledge") - assert result.output == "Knowledge has been reset.\n" + call_count = 0 + for crew in crews: + call_count += 1 + crew.reset_memories.assert_called_once_with(command_type="knowledge") + assert f"[Crew ({crew.id})] Knowledge has been reset." in result.output + + assert call_count == 2, "reset_memories should have been called twice" def test_reset_no_memory_flags(runner):