Compare commits

..

2 Commits

Author SHA1 Message Date
Joao Moura
d2e5775d7b Fix memory reset review issues 2026-06-16 18:47:26 -07:00
Joao Moura
6c185d22c1 Enhance memory reset functionality and JSON crew handling
- Added `reset_all` method to the `Memory` class to reset the entire memory store, ignoring `root_scope`.
- Updated the `Crew` class to utilize `reset_all` when resetting memory.
- Enhanced the `_reset_flow_memory` function to check for `Memory` instances and call `reset_all` accordingly.
- Introduced helper functions to load JSON crew configurations and handle project declarations, improving the reset command's flexibility.
- Added tests to validate the new JSON crew memory reset behavior and ensure proper handling of declared flow projects.
2026-06-16 16:27:07 -07:00
11 changed files with 300 additions and 183 deletions

View File

@@ -2275,6 +2275,8 @@ class Crew(FlowTrackable, BaseModel):
"""
def default_reset(memory: Any) -> Any:
if isinstance(memory, Memory):
return memory.reset_all()
return memory.reset()
def knowledge_reset(memory: Any) -> Any:

View File

@@ -15,13 +15,10 @@ from crewai.flow.flow_definition import (
FlowConversationalRouterDefinition,
FlowDefinition,
FlowDefinitionDiagnostic,
FlowDictStateDefinition,
FlowHumanFeedbackDefinition,
FlowMethodDefinition,
FlowPersistenceDefinition,
FlowPydanticStateDefinition,
FlowStateDefinition,
FlowUnknownStateDefinition,
_object_ref,
)
from crewai.flow.flow_wrappers import (
@@ -188,11 +185,12 @@ def _build_state_definition(
default = None
if isinstance(state_value, dict):
default = _serialize_static_value(state_value, diagnostics, "state.default")
return FlowDictStateDefinition(default=default)
return FlowStateDefinition(type="dict", default=default)
if isinstance(state_value, type) and issubclass(state_value, PydanticBaseModel):
return FlowPydanticStateDefinition(ref=_state_ref(state_value))
return FlowStateDefinition(type="pydantic", ref=_state_ref(state_value))
if isinstance(state_value, PydanticBaseModel):
return FlowPydanticStateDefinition(
return FlowStateDefinition(
type="pydantic",
ref=_state_ref(state_value),
default=_serialize_static_value(state_value, diagnostics, "state.default"),
)
@@ -203,7 +201,7 @@ def _build_state_definition(
message=f"could not serialize state type {_object_ref(state_value)}",
)
)
return FlowUnknownStateDefinition(ref=_state_ref(state_value))
return FlowStateDefinition(type="unknown", ref=_state_ref(state_value))
def _build_config_definition(

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import json
import logging
import re
from typing import Annotated, Any, Literal as TypingLiteral, TypeAlias
from typing import Any, Literal as TypingLiteral
from pydantic import (
BaseModel,
@@ -46,18 +46,14 @@ __all__ = [
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowDictStateDefinition",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowJsonSchemaStateDefinition",
"FlowMethodDefinition",
"FlowPersistenceDefinition",
"FlowPydanticStateDefinition",
"FlowStateDefinition",
"FlowToolActionDefinition",
"FlowUnknownStateDefinition",
]
@@ -78,114 +74,13 @@ class FlowDefinitionDiagnostic(BaseModel):
path: str | None = None
class FlowDictStateDefinition(BaseModel):
"""Static description of a plain dictionary Flow state contract."""
class FlowStateDefinition(BaseModel):
"""Static description of a Flow state contract."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["dict"] = Field(
default="dict",
description="Plain dictionary state with optional default values.",
examples=["dict"],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowPydanticStateDefinition(BaseModel):
"""Static description of an importable Pydantic Flow state contract."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["pydantic"] = Field(
default="pydantic",
description="Importable Pydantic model used as the Flow state type.",
examples=["pydantic"],
)
ref: str | None = Field(
default=None,
description="Import reference for the state model, formatted as module:qualname.",
examples=["my_project.flows:ResearchState"],
)
json_schema: dict[str, Any] | None = Field(
default=None,
description=(
"Fallback JSON Schema used when the Pydantic state ref is unavailable."
),
examples=[
{
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowJsonSchemaStateDefinition(BaseModel):
"""Static description of an inline JSON Schema Flow state contract."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["json_schema"] = Field(
default="json_schema",
description="Inline JSON Schema used as the Flow state contract.",
examples=["json_schema"],
)
json_schema: dict[str, Any] = Field(
description="JSON Schema used to validate and document flow state.",
examples=[
{
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowUnknownStateDefinition(BaseModel):
"""Static description of a state contract that could not be serialized."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["unknown"] = Field(
default="unknown",
description="Unknown state representation; runtime falls back to dictionary state.",
examples=["unknown"],
)
ref: str | None = Field(
default=None,
description="Best-effort import reference for the unknown state type.",
examples=["my_project.flows:CustomState"],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
FlowStateDefinition: TypeAlias = Annotated[
FlowDictStateDefinition
| FlowPydanticStateDefinition
| FlowJsonSchemaStateDefinition
| FlowUnknownStateDefinition,
Field(discriminator="type"),
]
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
ref: str | None = None
json_schema: dict[str, Any] | None = None
default: dict[str, Any] | None = None
class FlowConfigDefinition(BaseModel):

View File

@@ -193,24 +193,26 @@ def _build_definition_state_model(
kwargs = dict(state_definition.default or {})
model_class: type[BaseModel] | None = None
state_ref = getattr(state_definition, "ref", None)
if state_ref:
if state_definition.ref:
try:
resolved: Any = resolve_ref(state_ref, field="state")
resolved: Any = resolve_ref(state_definition.ref, field="state")
except Exception:
logger.warning("Could not import state ref %r", state_ref, exc_info=True)
logger.warning(
"Could not import state ref %r", state_definition.ref, exc_info=True
)
else:
if isinstance(resolved, type) and issubclass(resolved, BaseModel):
model_class = resolved
else:
logger.warning("State ref %r is not a pydantic model", state_ref)
logger.warning(
"State ref %r is not a pydantic model", state_definition.ref
)
json_schema = getattr(state_definition, "json_schema", None)
if model_class is None and json_schema:
if model_class is None and state_definition.json_schema:
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
try:
model_class = create_model_from_schema(json_schema)
model_class = create_model_from_schema(state_definition.json_schema)
except Exception:
logger.warning(
"Could not build a state model from the declared json_schema",

View File

@@ -149,6 +149,7 @@ class Memory(BaseModel):
)
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_reset_lock: Any = PrivateAttr(default_factory=threading.RLock)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
"""Deepcopy that handles unpickleable private attrs (ThreadPoolExecutor, Lock)."""
@@ -168,7 +169,10 @@ class Memory(BaseModel):
)
private = {}
for k, v in (self.__pydantic_private__ or {}).items():
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
if k in {"_save_pool", "_pending_lock", "_reset_lock"}:
attr = self.__private_attributes__[k]
private[k] = attr.get_default()
elif isinstance(v, (ThreadPoolExecutor, threading.Lock)):
attr = self.__private_attributes__[k]
private[k] = attr.get_default()
else:
@@ -275,22 +279,25 @@ class Memory(BaseModel):
If the pool has been shut down (e.g. after ``close()``), the save
runs synchronously as a fallback so late saves still succeed.
"""
ctx = contextvars.copy_context()
try:
future: Future[Any] = self._save_pool.submit(ctx.run, fn, *args, **kwargs)
except RuntimeError:
# Pool shut down -- run synchronously as fallback
future = Future()
with self._reset_lock:
ctx = contextvars.copy_context()
try:
result = fn(*args, **kwargs)
future.set_result(result)
except Exception as exc:
future.set_exception(exc)
future: Future[Any] = self._save_pool.submit(
ctx.run, fn, *args, **kwargs
)
except RuntimeError:
# Pool shut down -- run synchronously as fallback
future = Future()
try:
result = fn(*args, **kwargs)
future.set_result(result)
except Exception as exc:
future.set_exception(exc)
return future
with self._pending_lock:
self._pending_saves.append(future)
future.add_done_callback(self._on_save_done)
return future
with self._pending_lock:
self._pending_saves.append(future)
future.add_done_callback(self._on_save_done)
return future
def _on_save_done(self, future: Future[Any]) -> None:
"""Remove a completed future from the pending list and emit failure event if needed.
@@ -990,12 +997,20 @@ class Memory(BaseModel):
scope: Scope to reset. If None and root_scope is set, resets only
within root_scope. If None and no root_scope, resets all.
"""
effective_scope = scope
if effective_scope is None and self.root_scope:
effective_scope = self.root_scope
elif effective_scope is not None and self.root_scope:
effective_scope = join_scope_paths(self.root_scope, effective_scope)
self._storage.reset(scope_prefix=effective_scope)
with self._reset_lock:
self.drain_writes()
effective_scope = scope
if effective_scope is None and self.root_scope:
effective_scope = self.root_scope
elif effective_scope is not None and self.root_scope:
effective_scope = join_scope_paths(self.root_scope, effective_scope)
self._storage.reset(scope_prefix=effective_scope)
def reset_all(self) -> None:
"""Reset the entire backing memory store, ignoring ``root_scope``."""
with self._reset_lock:
self.drain_writes()
self._storage.reset(scope_prefix=None)
async def aextract_memories(self, content: str) -> list[str]:
"""Async variant of extract_memories."""

View File

@@ -6,7 +6,10 @@ from typing import Any
import click
from crewai.flow import Flow
from crewai.utilities.project_utils import get_crews, get_flows
from crewai.memory.unified_memory import Memory
from crewai.project.crew_loader import load_crew
from crewai.project.json_loader import find_crew_json_file
from crewai.utilities.project_utils import get_crews, get_flows, read_toml
def _reset_flow_memory(flow: Flow[Any]) -> None:
@@ -23,7 +26,9 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
if mem is None:
return
try:
if hasattr(mem, "reset"):
if isinstance(mem, Memory):
mem.reset_all()
elif hasattr(mem, "reset"):
mem.reset()
elif hasattr(mem, "_memory") and mem._memory is not None:
mem._memory.reset()
@@ -37,6 +42,38 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
click.echo(f"Memory reset skipped: {exc}", err=True)
def _current_project_declares_flow() -> bool:
try:
pyproject_data = read_toml()
except Exception:
return False
declared_type: str | None = (
pyproject_data.get("tool", {}).get("crewai", {}).get("type")
)
return declared_type == "flow"
def _get_json_crew() -> Any | None:
"""Load a JSON-first crew from the current project, if present."""
if _current_project_declares_flow():
return None
crew_path = find_crew_json_file()
if crew_path is None:
return None
try:
crew, _ = load_crew(crew_path)
except Exception as exc:
click.echo(
f"Skipping JSON crew at {crew_path}: failed to load ({exc}).",
err=True,
)
return None
return crew
def reset_memories_command(
memory: bool,
knowledge: bool,
@@ -61,6 +98,8 @@ def reset_memories_command(
return
crews = get_crews()
if json_crew := _get_json_crew():
crews.append(json_crew)
flows = get_flows()
if not crews and not flows:

View File

@@ -4,10 +4,12 @@ Non-core CLI tests (train, test, version, deploy, login, flow_add_crew)
have moved to lib/cli/tests/test_cli.py.
"""
from pathlib import Path
from unittest import mock
from click.testing import CliRunner
from crewai.crew import Crew
from crewai.memory.unified_memory import Memory
from crewai_cli.cli import reset_memories
import pytest
@@ -30,6 +32,8 @@ def mock_get_crews(mock_crew):
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
) as mock_get_crew, mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
yield mock_get_crew
@@ -170,6 +174,8 @@ def mock_get_flows(mock_flow):
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
) as mock_get_flow, mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
yield mock_get_flow
@@ -180,6 +186,33 @@ def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_unified_memory_uses_full_reset(runner, tmp_path):
flow = mock.Mock()
flow.name = "TestFlow"
flow.memory = Memory(
storage=str(tmp_path / "db"),
llm=mock.Mock(),
embedder=lambda texts: [[0.1] * 4 for _ in texts],
)
with mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[flow]
), mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
), mock.patch.object(
Memory, "reset_all"
) as reset_all, mock.patch.object(
Memory, "reset"
) as reset:
result = runner.invoke(reset_memories, ["-m"])
reset_all.assert_called_once_with()
reset.assert_not_called()
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-a"])
mock_flow.memory.reset.assert_called_once()
@@ -197,16 +230,83 @@ def test_reset_no_crew_or_flow_found(runner):
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
assert "No crew or flow found." in result.output
def test_reset_json_crew_memory(mock_crew, runner, monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{}")
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
return_value=(mock_crew, {}),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
def test_reset_invalid_json_crew_does_not_block_classic_crew(
mock_crew, runner, monkeypatch, tmp_path
):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{invalid")
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
side_effect=ValueError("invalid JSON"),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
assert "Skipping JSON crew at crew.jsonc: failed to load (invalid JSON)." in result.output
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
def test_reset_json_crew_skipped_for_declared_flow_project(
mock_crew, runner, monkeypatch, tmp_path
):
monkeypatch.chdir(tmp_path)
(tmp_path / "crew.jsonc").write_text("{}")
(tmp_path / "pyproject.toml").write_text('[tool.crewai]\ntype = "flow"\n')
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.load_crew",
return_value=(mock_crew, {}),
) as mock_load_crew:
result = runner.invoke(reset_memories, ["-m"])
mock_load_crew.assert_not_called()
mock_crew.reset_memories.assert_not_called()
assert "No crew or flow found." in result.output
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
with mock.patch(
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
@@ -223,6 +323,8 @@ def test_reset_flow_memory_none(runner):
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
), mock.patch(
"crewai.utilities.reset_memories._get_json_crew", return_value=None
):
result = runner.invoke(reset_memories, ["-m"])
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output

View File

@@ -8,6 +8,7 @@ not silently zero-fill vectors or return empty search results.
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
import pytest
@@ -97,6 +98,33 @@ def test_lancedb_reopened_store_detects_mismatch(lancedb_path: Path) -> None:
reopened.search([0.1] * 8)
def test_memory_reset_all_rebuilds_reopened_store_with_new_dimension(
lancedb_path: Path,
) -> None:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.unified_memory import Memory
old = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
old.save([_record(4)])
mem = Memory(
storage=str(lancedb_path),
llm=MagicMock(),
embedder=lambda texts: [[0.1] * 8 for _ in texts],
root_scope="/crew/test",
)
mem.reset_all()
mem.remember(
"new embedder output",
scope="/facts",
categories=["test"],
importance=0.5,
)
assert mem.recall("new embedder output", scope="/facts", depth="shallow")
def test_lancedb_matching_dim_still_works(lancedb_path: Path) -> None:
from crewai.memory.storage.lancedb_storage import LanceDBStorage

View File

@@ -954,6 +954,54 @@ def test_remember_many_returns_immediately(tmp_path: Path) -> None:
assert mem._storage.count() == 2
def test_reset_all_blocks_new_save_submission_until_reset_completes(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A save cannot be submitted between draining writes and resetting storage."""
from crewai.memory.unified_memory import Memory
mem = Memory(
storage=str(tmp_path / "db"),
llm=MagicMock(),
embedder=lambda texts: [[0.1] * 4 for _ in texts],
)
reset_started = threading.Event()
release_reset = threading.Event()
submission_returned = threading.Event()
order: list[str] = []
original_reset = mem._storage.reset
def blocking_reset(scope_prefix: str | None = None) -> None:
order.append("reset-start")
reset_started.set()
assert release_reset.wait(timeout=2)
original_reset(scope_prefix=scope_prefix)
order.append("reset-end")
def submit_save() -> None:
mem._submit_save(lambda: order.append("save"))
order.append("submit-returned")
submission_returned.set()
monkeypatch.setattr(mem._storage, "reset", blocking_reset)
reset_thread = threading.Thread(target=mem.reset_all)
reset_thread.start()
assert reset_started.wait(timeout=2)
submit_thread = threading.Thread(target=submit_save)
submit_thread.start()
assert not submission_returned.wait(timeout=0.1)
release_reset.set()
reset_thread.join(timeout=2)
submit_thread.join(timeout=2)
assert not reset_thread.is_alive()
assert not submit_thread.is_alive()
assert order.index("reset-end") < order.index("submit-returned")
def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None:
"""recall() should automatically wait for pending background saves."""
from crewai.memory.unified_memory import Memory

View File

@@ -4584,6 +4584,26 @@ def test_reset_knowledge_with_no_crew_knowledge(researcher, writer):
)
def test_reset_memory_uses_full_unified_memory_reset(researcher):
crew = Crew(
agents=[researcher],
process=Process.sequential,
tasks=[
Task(description="Task 1", expected_output="output", agent=researcher),
],
memory=True,
)
assert isinstance(crew._memory, Memory)
with patch.object(Memory, "reset_all") as reset_all, patch.object(
Memory, "reset"
) as reset:
crew.reset_memories(command_type="memory")
reset_all.assert_called_once_with()
reset.assert_not_called()
def test_reset_knowledge_with_only_crew_knowledge(researcher, writer):
mock_ks = MagicMock(spec=Knowledge)

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from typing import Annotated, Literal
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
import crewai.flow.dsl as flow_dsl
import crewai.flow.flow_definition as flow_definition
@@ -45,51 +45,19 @@ def test_flow_public_exports_are_explicit():
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowDictStateDefinition",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowJsonSchemaStateDefinition",
"FlowMethodDefinition",
"FlowPersistenceDefinition",
"FlowPydanticStateDefinition",
"FlowStateDefinition",
"FlowToolActionDefinition",
"FlowUnknownStateDefinition",
}
assert "build_flow_structure" in flow_visualization.__all__
assert "calculate_node_levels" not in flow_visualization.__all__
def test_flow_state_definition_uses_discriminated_branches():
definition = flow_definition.FlowDefinition.model_validate(
{
"name": "TypedStateFlow",
"state": {
"type": "json_schema",
"json_schema": {"type": "object"},
},
}
)
assert isinstance(
definition.state,
flow_definition.FlowJsonSchemaStateDefinition,
)
with pytest.raises(ValidationError, match="extra_forbidden"):
flow_definition.FlowDefinition.model_validate(
{
"name": "InvalidStateFlow",
"state": {
"type": "dict",
"ref": "my_project.flows:ResearchState",
},
}
)
def test_condition_combinators_return_nested_runtime_tree():
condition = and_("event_a", "event_b", or_("event_c"))