Fix crewai reset-memories when Embedding dimension mismatch (#2737)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled

* fix: support resetting memories after changing the Crew's embedder

The sources must not be added while initializing the Knowledge; otherwise, we could not reset it.

* chore: improve reset memory feedback

Previously, even when no memories were actually erased, we logged that they had been. From now on, the log will specify which memory has been reset.

* feat: improve get_crew discovery from a single file

Crew instances can now be discovered from any function or method with a return type annotation of -> Crew, as well as from module-level attributes assigned to a Crew instance. Additionally, crews can be retrieved from within a Flow.

* refactor: make add_sources a public method from Knowledge
This commit is contained in:
Lucas Gomide
2025-05-02 13:40:42 -03:00
committed by GitHub
parent 2902201bfa
commit f89c2bfb7e
5 changed files with 199 additions and 92 deletions

View File

@@ -2,7 +2,7 @@ import subprocess
import click import click
from crewai.cli.utils import get_crew from crewai.cli.utils import get_crews
def reset_memories_command( def reset_memories_command(
@@ -26,35 +26,47 @@ def reset_memories_command(
""" """
try: try:
crew = get_crew() if not any([long, short, entity, kickoff_outputs, knowledge, all]):
if not crew:
raise ValueError("No crew found.")
if all:
crew.reset_memories(command_type="all")
click.echo("All memories have been reset.")
return
if not any([long, short, entity, kickoff_outputs, knowledge]):
click.echo( click.echo(
"No memory type specified. Please specify at least one type to reset." "No memory type specified. Please specify at least one type to reset."
) )
return return
if long: crews = get_crews()
crew.reset_memories(command_type="long") if not crews:
click.echo("Long term memory has been reset.") raise ValueError("No crew found.")
if short: for crew in crews:
crew.reset_memories(command_type="short") if all:
click.echo("Short term memory has been reset.") crew.reset_memories(command_type="all")
if entity: click.echo(
crew.reset_memories(command_type="entity") f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
click.echo("Entity memory has been reset.") )
if kickoff_outputs: continue
crew.reset_memories(command_type="kickoff_outputs") if long:
click.echo("Latest Kickoff outputs stored has been reset.") crew.reset_memories(command_type="long")
if knowledge: click.echo(
crew.reset_memories(command_type="knowledge") f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
click.echo("Knowledge has been reset.") )
if short:
crew.reset_memories(command_type="short")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
)
if entity:
crew.reset_memories(command_type="entity")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
)
if kickoff_outputs:
crew.reset_memories(command_type="kickoff_outputs")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
)
if knowledge:
crew.reset_memories(command_type="knowledge")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True) click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -2,7 +2,8 @@ import os
import shutil import shutil
import sys import sys
from functools import reduce from functools import reduce
from typing import Any, Dict, List from inspect import isfunction, ismethod
from typing import Any, Dict, List, get_type_hints
import click import click
import tomli import tomli
@@ -10,6 +11,7 @@ from rich.console import Console
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS
from crewai.crew import Crew from crewai.crew import Crew
from crewai.flow import Flow
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
import tomllib import tomllib
@@ -250,11 +252,11 @@ def write_env_file(folder_path, env_vars):
file.write(f"{key}={value}\n") file.write(f"{key}={value}\n")
def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None: def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
"""Get the crew instance from the crew.py file.""" """Get the crew instances from the a file."""
crew_instances = []
try: try:
import importlib.util import importlib.util
import os
for root, _, files in os.walk("."): for root, _, files in os.walk("."):
if crew_path in files: if crew_path in files:
@@ -271,12 +273,10 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
spec.loader.exec_module(module) spec.loader.exec_module(module)
for attr_name in dir(module): for attr_name in dir(module):
attr = getattr(module, attr_name) module_attr = getattr(module, attr_name)
try:
if callable(attr) and hasattr(attr, "crew"):
crew_instance = attr().crew()
return crew_instance
try:
crew_instances.extend(fetch_crews(module_attr))
except Exception as e: except Exception as e:
print(f"Error processing attribute {attr_name}: {e}") print(f"Error processing attribute {attr_name}: {e}")
continue continue
@@ -286,7 +286,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
import traceback import traceback
print(f"Traceback: {traceback.format_exc()}") print(f"Traceback: {traceback.format_exc()}")
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
if require: if require:
console.print( console.print(
@@ -300,7 +299,6 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
if require: if require:
console.print("No valid Crew instance found in crew.py", style="bold red") console.print("No valid Crew instance found in crew.py", style="bold red")
raise SystemExit raise SystemExit
return None
except Exception as e: except Exception as e:
if require: if require:
@@ -308,4 +306,36 @@ def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
f"Unexpected error while loading crew: {str(e)}", style="bold red" f"Unexpected error while loading crew: {str(e)}", style="bold red"
) )
raise SystemExit raise SystemExit
return crew_instances
def get_crew_instance(module_attr) -> Crew | None:
    """Resolve a single attribute to a Crew instance, if it represents one.

    Three shapes are recognised, checked in this order:

    1. A callable exposing a truthy ``is_crew_class`` attribute: it is
       called with no arguments and the result's ``crew()`` is returned.
    2. A plain function or bound method whose return annotation is exactly
       ``Crew``: it is called with no arguments and its result is returned.
    3. An object that already is a ``Crew`` instance: returned as is.

    Args:
        module_attr: Any object, typically an attribute pulled from a module
            or from an instance via ``getattr``.

    Returns:
        The resolved ``Crew``, or ``None`` when the attribute matches none of
        the shapes above.
    """
    # Short-circuit order matters: only touch ``is_crew_class`` once we know
    # the attribute is callable and actually has it.
    is_crew_factory = (
        callable(module_attr)
        and hasattr(module_attr, "is_crew_class")
        and module_attr.is_crew_class
    )
    if is_crew_factory:
        return module_attr().crew()

    if ismethod(module_attr) or isfunction(module_attr):
        # Type hints are only inspected for real functions/methods.
        if get_type_hints(module_attr).get("return") is Crew:
            return module_attr()

    return module_attr if isinstance(module_attr, Crew) else None
def fetch_crews(module_attr) -> list[Crew]:
    """Collect every Crew reachable from a single attribute.

    The attribute itself is first resolved through ``get_crew_instance``.
    If the attribute is additionally a class that subclasses ``Flow``, the
    class is instantiated with no arguments and each name reported by
    ``dir()`` on that instance is resolved the same way.

    Args:
        module_attr: Any object, typically an attribute pulled from a module.

    Returns:
        A list of the crews found, possibly empty. The crew resolved from the
        attribute itself (if any) comes first, followed by crews found on the
        Flow instance in ``dir()`` order.
    """
    found: list[Crew] = []

    direct = get_crew_instance(module_attr)
    if direct:
        found.append(direct)

    # Only Flow subclasses are instantiated and searched member by member.
    if not (isinstance(module_attr, type) and issubclass(module_attr, Flow)):
        return found

    flow = module_attr()
    for member_name in dir(flow):
        candidate = get_crew_instance(getattr(flow, member_name))
        if candidate:
            found.append(candidate)

    return found

View File

@@ -304,7 +304,9 @@ class Crew(BaseModel):
"""Initialize private memory attributes.""" """Initialize private memory attributes."""
self._external_memory = ( self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally # External memory doesnt support a default value since it was designed to be managed entirely externally
self.external_memory.set_crew(self) if self.external_memory else None self.external_memory.set_crew(self)
if self.external_memory
else None
) )
self._long_term_memory = self.long_term_memory self._long_term_memory = self.long_term_memory
@@ -333,6 +335,7 @@ class Crew(BaseModel):
embedder=self.embedder, embedder=self.embedder,
collection_name="crew", collection_name="crew",
) )
self.knowledge.add_sources()
except Exception as e: except Exception as e:
self._logger.log( self._logger.log(
@@ -1369,8 +1372,6 @@ class Crew(BaseModel):
else: else:
self._reset_specific_memory(command_type) self._reset_specific_memory(command_type)
self._logger.log("info", f"{command_type} memory has been reset")
except Exception as e: except Exception as e:
error_msg = f"Failed to reset {command_type} memory: {str(e)}" error_msg = f"Failed to reset {command_type} memory: {str(e)}"
self._logger.log("error", error_msg) self._logger.log("error", error_msg)
@@ -1391,8 +1392,14 @@ class Crew(BaseModel):
if system is not None: if system is not None:
try: try:
system.reset() system.reset()
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to reset {name} memory") from e raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
) from e
def _reset_specific_memory(self, memory_type: str) -> None: def _reset_specific_memory(self, memory_type: str) -> None:
"""Reset a specific memory system. """Reset a specific memory system.
@@ -1421,5 +1428,11 @@ class Crew(BaseModel):
try: try:
memory_system.reset() memory_system.reset()
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset",
)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to reset {name} memory") from e raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
) from e

View File

@@ -41,7 +41,6 @@ class Knowledge(BaseModel):
) )
self.sources = sources self.sources = sources
self.storage.initialize_knowledge_storage() self.storage.initialize_knowledge_storage()
self._add_sources()
def query( def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
@@ -63,7 +62,7 @@ class Knowledge(BaseModel):
) )
return results return results
def _add_sources(self): def add_sources(self):
try: try:
for source in self.sources: for source in self.sources:
source.storage = self.storage source.storage = self.storage

View File

@@ -18,6 +18,7 @@ from crewai.cli.cli import (
train, train,
version, version,
) )
from crewai.crew import Crew
@pytest.fixture @pytest.fixture
@@ -55,81 +56,133 @@ def test_train_invalid_string_iterations(train_crew, runner):
) )
@mock.patch("crewai.cli.reset_memories_command.get_crew") @pytest.fixture
def test_reset_all_memories(mock_get_crew, runner): def mock_crew():
mock_crew = mock.Mock() _mock = mock.Mock(spec=Crew, name="test_crew")
mock_get_crew.return_value = mock_crew _mock.name = "test_crew"
return _mock
@pytest.fixture
def mock_get_crews(mock_crew):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
) as mock_get_crew:
yield mock_get_crew
def test_reset_all_memories(mock_get_crews, runner):
result = runner.invoke(reset_memories, ["-a"]) result = runner.invoke(reset_memories, ["-a"])
mock_crew.reset_memories.assert_called_once_with(command_type="all") call_count = 0
assert result.output == "All memories have been reset.\n" for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="all")
assert (
f"[Crew ({crew.name})] Reset memories command has been completed."
in result.output
)
call_count += 1
assert call_count == 1, "reset_memories should have been called once"
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_short_term_memories(mock_get_crews, runner):
def test_reset_short_term_memories(mock_get_crew, runner):
mock_crew = mock.Mock()
mock_get_crew.return_value = mock_crew
result = runner.invoke(reset_memories, ["-s"]) result = runner.invoke(reset_memories, ["-s"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="short")
assert (
f"[Crew ({crew.name})] Short term memory has been reset." in result.output
)
call_count += 1
mock_crew.reset_memories.assert_called_once_with(command_type="short") assert call_count == 1, "reset_memories should have been called once"
assert result.output == "Short term memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_entity_memories(mock_get_crews, runner):
def test_reset_entity_memories(mock_get_crew, runner):
mock_crew = mock.Mock()
mock_get_crew.return_value = mock_crew
result = runner.invoke(reset_memories, ["-e"]) result = runner.invoke(reset_memories, ["-e"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="entity")
assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output
call_count += 1
mock_crew.reset_memories.assert_called_once_with(command_type="entity") assert call_count == 1, "reset_memories should have been called once"
assert result.output == "Entity memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_long_term_memories(mock_get_crews, runner):
def test_reset_long_term_memories(mock_get_crew, runner):
mock_crew = mock.Mock()
mock_get_crew.return_value = mock_crew
result = runner.invoke(reset_memories, ["-l"]) result = runner.invoke(reset_memories, ["-l"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="long")
assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output
call_count += 1
mock_crew.reset_memories.assert_called_once_with(command_type="long") assert call_count == 1, "reset_memories should have been called once"
assert result.output == "Long term memory has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_kickoff_outputs(mock_get_crews, runner):
def test_reset_kickoff_outputs(mock_get_crew, runner):
mock_crew = mock.Mock()
mock_get_crew.return_value = mock_crew
result = runner.invoke(reset_memories, ["-k"]) result = runner.invoke(reset_memories, ["-k"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
assert (
f"[Crew ({crew.name})] Latest Kickoff outputs stored has been reset."
in result.output
)
call_count += 1
mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs") assert call_count == 1, "reset_memories should have been called once"
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_multiple_memory_flags(mock_get_crews, runner):
def test_reset_multiple_memory_flags(mock_get_crew, runner):
mock_crew = mock.Mock()
mock_get_crew.return_value = mock_crew
result = runner.invoke(reset_memories, ["-s", "-l"]) result = runner.invoke(reset_memories, ["-s", "-l"])
call_count = 0
for crew in mock_get_crews.return_value:
crew.reset_memories.assert_has_calls(
[mock.call(command_type="long"), mock.call(command_type="short")]
)
assert (
f"[Crew ({crew.name})] Long term memory has been reset.\n"
f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output
)
call_count += 1
# Check that reset_memories was called twice with the correct arguments assert call_count == 1, "reset_memories should have been called once"
assert mock_crew.reset_memories.call_count == 2
mock_crew.reset_memories.assert_has_calls(
[mock.call(command_type="long"), mock.call(command_type="short")]
)
assert (
result.output
== "Long term memory has been reset.\nShort term memory has been reset.\n"
)
@mock.patch("crewai.cli.reset_memories_command.get_crew") def test_reset_knowledge(mock_get_crews, runner):
def test_reset_knowledge(mock_get_crew, runner): result = runner.invoke(reset_memories, ["--knowledge"])
mock_crew = mock.Mock() call_count = 0
mock_get_crew.return_value = mock_crew for crew in mock_get_crews.return_value:
crew.reset_memories.assert_called_once_with(command_type="knowledge")
assert f"[Crew ({crew.name})] Knowledge has been reset." in result.output
call_count += 1
assert call_count == 1, "reset_memories should have been called once"
def test_reset_memory_from_many_crews(mock_get_crews, runner):
crews = []
for crew_id in ["id-1234", "id-5678"]:
mock_crew = mock.Mock(spec=Crew)
mock_crew.name = None
mock_crew.id = crew_id
crews.append(mock_crew)
mock_get_crews.return_value = crews
# Run the command
result = runner.invoke(reset_memories, ["--knowledge"]) result = runner.invoke(reset_memories, ["--knowledge"])
mock_crew.reset_memories.assert_called_once_with(command_type="knowledge") call_count = 0
assert result.output == "Knowledge has been reset.\n" for crew in crews:
call_count += 1
crew.reset_memories.assert_called_once_with(command_type="knowledge")
assert f"[Crew ({crew.id})] Knowledge has been reset." in result.output
assert call_count == 2, "reset_memories should have been called twice"
def test_reset_no_memory_flags(runner): def test_reset_no_memory_flags(runner):