Compare commits

...

2 Commits

Author SHA1 Message Date
Joao Moura
d2e5775d7b Fix memory reset review issues 2026-06-16 18:47:26 -07:00
Joao Moura
6c185d22c1 Enhance memory reset functionality and JSON crew handling
- Added `reset_all` method to the `Memory` class to reset the entire memory store, ignoring `root_scope`.
- Updated the `Crew` class to utilize `reset_all` when resetting memory.
- Enhanced the `_reset_flow_memory` function to check for `Memory` instances and call `reset_all` accordingly.
- Introduced helper functions to load JSON crew configurations and handle project declarations, improving the reset command's flexibility.
- Added tests to validate the new JSON crew memory reset behavior and ensure proper handling of declared flow projects.
2026-06-16 16:27:07 -07:00
7 changed files with 277 additions and 23 deletions

View File

@@ -2275,6 +2275,8 @@ class Crew(FlowTrackable, BaseModel):
"""
def default_reset(memory: Any) -> Any:
if isinstance(memory, Memory):
return memory.reset_all()
return memory.reset()
def knowledge_reset(memory: Any) -> Any:

View File

@@ -149,6 +149,7 @@ class Memory(BaseModel):
)
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_reset_lock: Any = PrivateAttr(default_factory=threading.RLock)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
"""Deepcopy that handles unpickleable private attrs (ThreadPoolExecutor, Lock)."""
@@ -168,7 +169,10 @@ class Memory(BaseModel):
)
private = {}
for k, v in (self.__pydantic_private__ or {}).items():
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
if k in {"_save_pool", "_pending_lock", "_reset_lock"}:
attr = self.__private_attributes__[k]
private[k] = attr.get_default()
elif isinstance(v, (ThreadPoolExecutor, threading.Lock)):
attr = self.__private_attributes__[k]
private[k] = attr.get_default()
else:
@@ -275,22 +279,25 @@ class Memory(BaseModel):
If the pool has been shut down (e.g. after ``close()``), the save
runs synchronously as a fallback so late saves still succeed.
"""
ctx = contextvars.copy_context()
try:
future: Future[Any] = self._save_pool.submit(ctx.run, fn, *args, **kwargs)
except RuntimeError:
# Pool shut down -- run synchronously as fallback
future = Future()
with self._reset_lock:
ctx = contextvars.copy_context()
try:
result = fn(*args, **kwargs)
future.set_result(result)
except Exception as exc:
future.set_exception(exc)
future: Future[Any] = self._save_pool.submit(
ctx.run, fn, *args, **kwargs
)
except RuntimeError:
# Pool shut down -- run synchronously as fallback
future = Future()
try:
result = fn(*args, **kwargs)
future.set_result(result)
except Exception as exc:
future.set_exception(exc)
return future
with self._pending_lock:
self._pending_saves.append(future)
future.add_done_callback(self._on_save_done)
return future
with self._pending_lock:
self._pending_saves.append(future)
future.add_done_callback(self._on_save_done)
return future
def _on_save_done(self, future: Future[Any]) -> None:
"""Remove a completed future from the pending list and emit failure event if needed.
@@ -990,12 +997,20 @@ class Memory(BaseModel):
scope: Scope to reset. If None and root_scope is set, resets only
within root_scope. If None and no root_scope, resets all.
"""
effective_scope = scope
if effective_scope is None and self.root_scope:
effective_scope = self.root_scope
elif effective_scope is not None and self.root_scope:
effective_scope = join_scope_paths(self.root_scope, effective_scope)
self._storage.reset(scope_prefix=effective_scope)
with self._reset_lock:
self.drain_writes()
effective_scope = scope
if effective_scope is None and self.root_scope:
effective_scope = self.root_scope
elif effective_scope is not None and self.root_scope:
effective_scope = join_scope_paths(self.root_scope, effective_scope)
self._storage.reset(scope_prefix=effective_scope)
def reset_all(self) -> None:
"""Reset the entire backing memory store, ignoring ``root_scope``."""
with self._reset_lock:
self.drain_writes()
self._storage.reset(scope_prefix=None)
async def aextract_memories(self, content: str) -> list[str]:
"""Async variant of extract_memories."""

View File

@@ -6,7 +6,10 @@ from typing import Any
import click
from crewai.flow import Flow
from crewai.utilities.project_utils import get_crews, get_flows
from crewai.memory.unified_memory import Memory
from crewai.project.crew_loader import load_crew
from crewai.project.json_loader import find_crew_json_file
from crewai.utilities.project_utils import get_crews, get_flows, read_toml
def _reset_flow_memory(flow: Flow[Any]) -> None:
@@ -23,7 +26,9 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
if mem is None:
return
try:
if hasattr(mem, "reset"):
if isinstance(mem, Memory):
mem.reset_all()
elif hasattr(mem, "reset"):
mem.reset()
elif hasattr(mem, "_memory") and mem._memory is not None:
mem._memory.reset()
@@ -37,6 +42,38 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
click.echo(f"Memory reset skipped: {exc}", err=True)
def _current_project_declares_flow() -> bool:
try:
pyproject_data = read_toml()
except Exception:
return False
declared_type: str | None = (
pyproject_data.get("tool", {}).get("crewai", {}).get("type")
)
return declared_type == "flow"
def _get_json_crew() -> Any | None:
"""Load a JSON-first crew from the current project, if present."""
if _current_project_declares_flow():
return None
crew_path = find_crew_json_file()
if crew_path is None:
return None
try:
crew, _ = load_crew(crew_path)
except Exception as exc:
click.echo(
f"Skipping JSON crew at {crew_path}: failed to load ({exc}).",
err=True,
)
return None
return crew
def reset_memories_command(
memory: bool,
knowledge: bool,
@@ -61,6 +98,8 @@ def reset_memories_command(
return
crews = get_crews()
if json_crew := _get_json_crew():
crews.append(json_crew)
flows = get_flows()
if not crews and not flows:

View File

@@ -4,10 +4,12 @@ Non-core CLI tests (train, test, version, deploy, login, flow_add_crew)
have moved to lib/cli/tests/test_cli.py.
"""
from pathlib import Path
from unittest import mock
from click.testing import CliRunner
from crewai.crew import Crew
from crewai.memory.unified_memory import Memory
from crewai_cli.cli import reset_memories
import pytest
@@ -30,6 +32,8 @@ def mock_get_crews(mock_crew):
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
) as mock_get_crew, mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
yield mock_get_crew
@@ -170,6 +174,8 @@ def mock_get_flows(mock_flow):
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
) as mock_get_flow, mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
yield mock_get_flow
@@ -180,6 +186,33 @@ def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_unified_memory_uses_full_reset(runner, tmp_path):
flow = mock.Mock()
flow.name = "TestFlow"
flow.memory = Memory(
storage=str(tmp_path / "db"),
llm=mock.Mock(),
embedder=lambda texts: [[0.1] * 4 for _ in texts],
)
with mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[flow]
), mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
), mock.patch.object(
Memory, "reset_all"
) as reset_all, mock.patch.object(
Memory, "reset"
) as reset:
result = runner.invoke(reset_memories, ["-m"])
reset_all.assert_called_once_with()
reset.assert_not_called()
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-a"])
mock_flow.memory.reset.assert_called_once()
@@ -197,16 +230,83 @@ def test_reset_no_crew_or_flow_found(runner):
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
assert "No crew or flow found." in result.output
def test_reset_json_crew_memory(mock_crew, runner, monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{}")
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
return_value=(mock_crew, {}),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
def test_reset_invalid_json_crew_does_not_block_classic_crew(
mock_crew, runner, monkeypatch, tmp_path
):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{invalid")
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
side_effect=ValueError("invalid JSON"),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
assert "Skipping JSON crew at crew.jsonc: failed to load (invalid JSON)." in result.output
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
def test_reset_json_crew_skipped_for_declared_flow_project(
mock_crew, runner, monkeypatch, tmp_path
):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{}")
(tmp_path / "pyproject.toml").write_text('[tool.crewai]\ntype = "flow"\n')
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
return_value=(mock_crew, {}),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_not_called()
mock_crew.reset_memories.assert_not_called()
assert "No crew or flow found." in result.output
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
@@ -223,6 +323,8 @@ def test_reset_flow_memory_none(runner):
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output

View File

@@ -8,6 +8,7 @@ not silently zero-fill vectors or return empty search results.
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
import pytest
@@ -97,6 +98,33 @@ def test_lancedb_reopened_store_detects_mismatch(lancedb_path: Path) -> None:
reopened.search([0.1] * 8)
def test_memory_reset_all_rebuilds_reopened_store_with_new_dimension(
lancedb_path: Path,
) -> None:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.unified_memory import Memory
old = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
old.save([_record(4)])
mem = Memory(
storage=str(lancedb_path),
llm=MagicMock(),
embedder=lambda texts: [[0.1] * 8 for _ in texts],
root_scope="/crew/test",
)
mem.reset_all()
mem.remember(
"new embedder output",
scope="/facts",
categories=["test"],
importance=0.5,
)
assert mem.recall("new embedder output", scope="/facts", depth="shallow")
def test_lancedb_matching_dim_still_works(lancedb_path: Path) -> None:
from crewai.memory.storage.lancedb_storage import LanceDBStorage

View File

@@ -954,6 +954,54 @@ def test_remember_many_returns_immediately(tmp_path: Path) -> None:
assert mem._storage.count() == 2
def test_reset_all_blocks_new_save_submission_until_reset_completes(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A save cannot be submitted between draining writes and resetting storage."""
from crewai.memory.unified_memory import Memory
mem = Memory(
storage=str(tmp_path / "db"),
llm=MagicMock(),
embedder=lambda texts: [[0.1] * 4 for _ in texts],
)
reset_started = threading.Event()
release_reset = threading.Event()
submission_returned = threading.Event()
order: list[str] = []
original_reset = mem._storage.reset
def blocking_reset(scope_prefix: str | None = None) -> None:
order.append("reset-start")
reset_started.set()
assert release_reset.wait(timeout=2)
original_reset(scope_prefix=scope_prefix)
order.append("reset-end")
def submit_save() -> None:
mem._submit_save(lambda: order.append("save"))
order.append("submit-returned")
submission_returned.set()
monkeypatch.setattr(mem._storage, "reset", blocking_reset)
reset_thread = threading.Thread(target=mem.reset_all)
reset_thread.start()
assert reset_started.wait(timeout=2)
submit_thread = threading.Thread(target=submit_save)
submit_thread.start()
assert not submission_returned.wait(timeout=0.1)
release_reset.set()
reset_thread.join(timeout=2)
submit_thread.join(timeout=2)
assert not reset_thread.is_alive()
assert not submit_thread.is_alive()
assert order.index("reset-end") < order.index("submit-returned")
def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None:
"""recall() should automatically wait for pending background saves."""
from crewai.memory.unified_memory import Memory

View File

@@ -4584,6 +4584,26 @@ def test_reset_knowledge_with_no_crew_knowledge(researcher, writer):
)
def test_reset_memory_uses_full_unified_memory_reset(researcher):
crew = Crew(
agents=[researcher],
process=Process.sequential,
tasks=[
Task(description="Task 1", expected_output="output", agent=researcher),
],
memory=True,
)
assert isinstance(crew._memory, Memory)
with patch.object(Memory, "reset_all") as reset_all, patch.object(
Memory, "reset"
) as reset:
crew.reset_memories(command_type="memory")
reset_all.assert_called_once_with()
reset.assert_not_called()
def test_reset_knowledge_with_only_crew_knowledge(researcher, writer):
mock_ks = MagicMock(spec=Knowledge)