mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-17 06:08:22 +00:00
Compare commits
2 Commits
flow-scrip
...
joaomdmour
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2e5775d7b | ||
|
|
6c185d22c1 |
@@ -2275,6 +2275,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
def default_reset(memory: Any) -> Any:
|
||||
if isinstance(memory, Memory):
|
||||
return memory.reset_all()
|
||||
return memory.reset()
|
||||
|
||||
def knowledge_reset(memory: Any) -> Any:
|
||||
|
||||
@@ -149,6 +149,7 @@ class Memory(BaseModel):
|
||||
)
|
||||
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_reset_lock: Any = PrivateAttr(default_factory=threading.RLock)
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
|
||||
"""Deepcopy that handles unpickleable private attrs (ThreadPoolExecutor, Lock)."""
|
||||
@@ -168,7 +169,10 @@ class Memory(BaseModel):
|
||||
)
|
||||
private = {}
|
||||
for k, v in (self.__pydantic_private__ or {}).items():
|
||||
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
|
||||
if k in {"_save_pool", "_pending_lock", "_reset_lock"}:
|
||||
attr = self.__private_attributes__[k]
|
||||
private[k] = attr.get_default()
|
||||
elif isinstance(v, (ThreadPoolExecutor, threading.Lock)):
|
||||
attr = self.__private_attributes__[k]
|
||||
private[k] = attr.get_default()
|
||||
else:
|
||||
@@ -275,22 +279,25 @@ class Memory(BaseModel):
|
||||
If the pool has been shut down (e.g. after ``close()``), the save
|
||||
runs synchronously as a fallback so late saves still succeed.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
try:
|
||||
future: Future[Any] = self._save_pool.submit(ctx.run, fn, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
with self._reset_lock:
|
||||
ctx = contextvars.copy_context()
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
future.set_result(result)
|
||||
except Exception as exc:
|
||||
future.set_exception(exc)
|
||||
future: Future[Any] = self._save_pool.submit(
|
||||
ctx.run, fn, *args, **kwargs
|
||||
)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
future.set_result(result)
|
||||
except Exception as exc:
|
||||
future.set_exception(exc)
|
||||
return future
|
||||
with self._pending_lock:
|
||||
self._pending_saves.append(future)
|
||||
future.add_done_callback(self._on_save_done)
|
||||
return future
|
||||
with self._pending_lock:
|
||||
self._pending_saves.append(future)
|
||||
future.add_done_callback(self._on_save_done)
|
||||
return future
|
||||
|
||||
def _on_save_done(self, future: Future[Any]) -> None:
|
||||
"""Remove a completed future from the pending list and emit failure event if needed.
|
||||
@@ -990,12 +997,20 @@ class Memory(BaseModel):
|
||||
scope: Scope to reset. If None and root_scope is set, resets only
|
||||
within root_scope. If None and no root_scope, resets all.
|
||||
"""
|
||||
effective_scope = scope
|
||||
if effective_scope is None and self.root_scope:
|
||||
effective_scope = self.root_scope
|
||||
elif effective_scope is not None and self.root_scope:
|
||||
effective_scope = join_scope_paths(self.root_scope, effective_scope)
|
||||
self._storage.reset(scope_prefix=effective_scope)
|
||||
with self._reset_lock:
|
||||
self.drain_writes()
|
||||
effective_scope = scope
|
||||
if effective_scope is None and self.root_scope:
|
||||
effective_scope = self.root_scope
|
||||
elif effective_scope is not None and self.root_scope:
|
||||
effective_scope = join_scope_paths(self.root_scope, effective_scope)
|
||||
self._storage.reset(scope_prefix=effective_scope)
|
||||
|
||||
def reset_all(self) -> None:
|
||||
"""Reset the entire backing memory store, ignoring ``root_scope``."""
|
||||
with self._reset_lock:
|
||||
self.drain_writes()
|
||||
self._storage.reset(scope_prefix=None)
|
||||
|
||||
async def aextract_memories(self, content: str) -> list[str]:
|
||||
"""Async variant of extract_memories."""
|
||||
|
||||
@@ -6,7 +6,10 @@ from typing import Any
|
||||
import click
|
||||
|
||||
from crewai.flow import Flow
|
||||
from crewai.utilities.project_utils import get_crews, get_flows
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.project.crew_loader import load_crew
|
||||
from crewai.project.json_loader import find_crew_json_file
|
||||
from crewai.utilities.project_utils import get_crews, get_flows, read_toml
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
@@ -23,7 +26,9 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
if mem is None:
|
||||
return
|
||||
try:
|
||||
if hasattr(mem, "reset"):
|
||||
if isinstance(mem, Memory):
|
||||
mem.reset_all()
|
||||
elif hasattr(mem, "reset"):
|
||||
mem.reset()
|
||||
elif hasattr(mem, "_memory") and mem._memory is not None:
|
||||
mem._memory.reset()
|
||||
@@ -37,6 +42,38 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
click.echo(f"Memory reset skipped: {exc}", err=True)
|
||||
|
||||
|
||||
def _current_project_declares_flow() -> bool:
|
||||
try:
|
||||
pyproject_data = read_toml()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
declared_type: str | None = (
|
||||
pyproject_data.get("tool", {}).get("crewai", {}).get("type")
|
||||
)
|
||||
return declared_type == "flow"
|
||||
|
||||
|
||||
def _get_json_crew() -> Any | None:
|
||||
"""Load a JSON-first crew from the current project, if present."""
|
||||
if _current_project_declares_flow():
|
||||
return None
|
||||
|
||||
crew_path = find_crew_json_file()
|
||||
if crew_path is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
crew, _ = load_crew(crew_path)
|
||||
except Exception as exc:
|
||||
click.echo(
|
||||
f"Skipping JSON crew at {crew_path}: failed to load ({exc}).",
|
||||
err=True,
|
||||
)
|
||||
return None
|
||||
return crew
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
memory: bool,
|
||||
knowledge: bool,
|
||||
@@ -61,6 +98,8 @@ def reset_memories_command(
|
||||
return
|
||||
|
||||
crews = get_crews()
|
||||
if json_crew := _get_json_crew():
|
||||
crews.append(json_crew)
|
||||
flows = get_flows()
|
||||
|
||||
if not crews and not flows:
|
||||
|
||||
@@ -4,10 +4,12 @@ Non-core CLI tests (train, test, version, deploy, login, flow_add_crew)
|
||||
have moved to lib/cli/tests/test_cli.py.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from click.testing import CliRunner
|
||||
from crewai.crew import Crew
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai_cli.cli import reset_memories
|
||||
import pytest
|
||||
|
||||
@@ -30,6 +32,8 @@ def mock_get_crews(mock_crew):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew, mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
yield mock_get_crew
|
||||
|
||||
@@ -170,6 +174,8 @@ def mock_get_flows(mock_flow):
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
) as mock_get_flow, mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
yield mock_get_flow
|
||||
|
||||
@@ -180,6 +186,33 @@ def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_unified_memory_uses_full_reset(runner, tmp_path):
|
||||
flow = mock.Mock()
|
||||
flow.name = "TestFlow"
|
||||
flow.memory = Memory(
|
||||
storage=str(tmp_path / "db"),
|
||||
llm=mock.Mock(),
|
||||
embedder=lambda texts: [[0.1] * 4 for _ in texts],
|
||||
)
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
), mock.patch.object(
|
||||
Memory, "reset_all"
|
||||
) as reset_all, mock.patch.object(
|
||||
Memory, "reset"
|
||||
) as reset:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
reset_all.assert_called_once_with()
|
||||
reset.assert_not_called()
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
@@ -197,16 +230,83 @@ def test_reset_no_crew_or_flow_found(runner):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_json_crew_memory(mock_crew, runner, monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{}")
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
return_value=(mock_crew, {}),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_invalid_json_crew_does_not_block_classic_crew(
|
||||
mock_crew, runner, monkeypatch, tmp_path
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{invalid")
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
side_effect=ValueError("invalid JSON"),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert "Skipping JSON crew at crew.jsonc: failed to load (invalid JSON)." in result.output
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_json_crew_skipped_for_declared_flow_project(
|
||||
mock_crew, runner, monkeypatch, tmp_path
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{}")
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.crewai]\ntype = "flow"\n')
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
return_value=(mock_crew, {}),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_not_called()
|
||||
mock_crew.reset_memories.assert_not_called()
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
@@ -223,6 +323,8 @@ def test_reset_flow_memory_none(runner):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
|
||||
|
||||
@@ -8,6 +8,7 @@ not silently zero-fill vectors or return empty search results.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -97,6 +98,33 @@ def test_lancedb_reopened_store_detects_mismatch(lancedb_path: Path) -> None:
|
||||
reopened.search([0.1] * 8)
|
||||
|
||||
|
||||
def test_memory_reset_all_rebuilds_reopened_store_with_new_dimension(
|
||||
lancedb_path: Path,
|
||||
) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
old = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
|
||||
old.save([_record(4)])
|
||||
|
||||
mem = Memory(
|
||||
storage=str(lancedb_path),
|
||||
llm=MagicMock(),
|
||||
embedder=lambda texts: [[0.1] * 8 for _ in texts],
|
||||
root_scope="/crew/test",
|
||||
)
|
||||
|
||||
mem.reset_all()
|
||||
mem.remember(
|
||||
"new embedder output",
|
||||
scope="/facts",
|
||||
categories=["test"],
|
||||
importance=0.5,
|
||||
)
|
||||
|
||||
assert mem.recall("new embedder output", scope="/facts", depth="shallow")
|
||||
|
||||
|
||||
def test_lancedb_matching_dim_still_works(lancedb_path: Path) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
|
||||
@@ -954,6 +954,54 @@ def test_remember_many_returns_immediately(tmp_path: Path) -> None:
|
||||
assert mem._storage.count() == 2
|
||||
|
||||
|
||||
def test_reset_all_blocks_new_save_submission_until_reset_completes(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A save cannot be submitted between draining writes and resetting storage."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mem = Memory(
|
||||
storage=str(tmp_path / "db"),
|
||||
llm=MagicMock(),
|
||||
embedder=lambda texts: [[0.1] * 4 for _ in texts],
|
||||
)
|
||||
reset_started = threading.Event()
|
||||
release_reset = threading.Event()
|
||||
submission_returned = threading.Event()
|
||||
order: list[str] = []
|
||||
original_reset = mem._storage.reset
|
||||
|
||||
def blocking_reset(scope_prefix: str | None = None) -> None:
|
||||
order.append("reset-start")
|
||||
reset_started.set()
|
||||
assert release_reset.wait(timeout=2)
|
||||
original_reset(scope_prefix=scope_prefix)
|
||||
order.append("reset-end")
|
||||
|
||||
def submit_save() -> None:
|
||||
mem._submit_save(lambda: order.append("save"))
|
||||
order.append("submit-returned")
|
||||
submission_returned.set()
|
||||
|
||||
monkeypatch.setattr(mem._storage, "reset", blocking_reset)
|
||||
|
||||
reset_thread = threading.Thread(target=mem.reset_all)
|
||||
reset_thread.start()
|
||||
assert reset_started.wait(timeout=2)
|
||||
|
||||
submit_thread = threading.Thread(target=submit_save)
|
||||
submit_thread.start()
|
||||
assert not submission_returned.wait(timeout=0.1)
|
||||
|
||||
release_reset.set()
|
||||
reset_thread.join(timeout=2)
|
||||
submit_thread.join(timeout=2)
|
||||
|
||||
assert not reset_thread.is_alive()
|
||||
assert not submit_thread.is_alive()
|
||||
assert order.index("reset-end") < order.index("submit-returned")
|
||||
|
||||
|
||||
def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
"""recall() should automatically wait for pending background saves."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
@@ -4584,6 +4584,26 @@ def test_reset_knowledge_with_no_crew_knowledge(researcher, writer):
|
||||
)
|
||||
|
||||
|
||||
def test_reset_memory_uses_full_unified_memory_reset(researcher):
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||
],
|
||||
memory=True,
|
||||
)
|
||||
|
||||
assert isinstance(crew._memory, Memory)
|
||||
with patch.object(Memory, "reset_all") as reset_all, patch.object(
|
||||
Memory, "reset"
|
||||
) as reset:
|
||||
crew.reset_memories(command_type="memory")
|
||||
|
||||
reset_all.assert_called_once_with()
|
||||
reset.assert_not_called()
|
||||
|
||||
|
||||
def test_reset_knowledge_with_only_crew_knowledge(researcher, writer):
|
||||
mock_ks = MagicMock(spec=Knowledge)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user