mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 14:52:36 +00:00
Enhance memory reset functionality in CLI commands
- Introduced flow memory reset capabilities in the `reset_memories_command`, allowing for both crew and flow memory resets. - Added a new utility function `_reset_flow_memory` to handle memory resets for individual flow instances, improving modularity and clarity. - Updated the `get_flows` utility to discover flow instances from project files, enhancing the CLI's ability to manage flow states. - Expanded test coverage to validate the new flow memory reset features, ensuring robust functionality and error handling.
This commit is contained in:
@@ -2,7 +2,30 @@ import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.utils import get_crews
|
||||
from crewai.cli.utils import get_crews, get_flows
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow) -> None:
|
||||
"""Reset memory for a single flow instance.
|
||||
|
||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||
(delegates to the underlying ._memory). Silently succeeds when the
|
||||
storage directory does not exist yet (nothing to reset).
|
||||
|
||||
Args:
|
||||
flow: The flow instance whose memory should be reset.
|
||||
"""
|
||||
mem = flow.memory
|
||||
if mem is None:
|
||||
return
|
||||
try:
|
||||
if hasattr(mem, "reset"):
|
||||
mem.reset()
|
||||
elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"):
|
||||
mem._memory.reset()
|
||||
except (FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
@@ -12,7 +35,7 @@ def reset_memories_command(
|
||||
kickoff_outputs: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""Reset the crew memories.
|
||||
"""Reset the crew and flow memories.
|
||||
|
||||
Args:
|
||||
memory: Whether to reset the unified memory.
|
||||
@@ -29,8 +52,11 @@ def reset_memories_command(
|
||||
return
|
||||
|
||||
crews = get_crews()
|
||||
if not crews:
|
||||
raise ValueError("No crew found.")
|
||||
flows = get_flows()
|
||||
|
||||
if not crews and not flows:
|
||||
raise ValueError("No crew or flow found.")
|
||||
|
||||
for crew in crews:
|
||||
if all:
|
||||
crew.reset_memories(command_type="all")
|
||||
@@ -59,6 +85,20 @@ def reset_memories_command(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Agents knowledge has been reset."
|
||||
)
|
||||
|
||||
for flow in flows:
|
||||
flow_name = flow.name or flow.__class__.__name__
|
||||
if all:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Reset memories command has been completed."
|
||||
)
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
click.echo(e.output, err=True)
|
||||
|
||||
@@ -386,6 +386,109 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||
|
||||
Args:
|
||||
module_attr: An attribute from a loaded module.
|
||||
|
||||
Returns:
|
||||
A Flow instance if the attribute is a valid user-defined Flow subclass,
|
||||
None otherwise.
|
||||
"""
|
||||
if (
|
||||
isinstance(module_attr, type)
|
||||
and issubclass(module_attr, Flow)
|
||||
and module_attr is not Flow
|
||||
):
|
||||
try:
|
||||
return module_attr()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
_SKIP_DIRS = frozenset(
|
||||
{".venv", "venv", ".git", "__pycache__", "node_modules", ".tox", ".nox"}
|
||||
)
|
||||
|
||||
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
"""Get the flow instances from project files.
|
||||
|
||||
Walks the project directory looking for files matching ``flow_path``
|
||||
(default ``main.py``), loads each module, and extracts Flow subclass
|
||||
instances. Directories that are clearly not user source code (virtual
|
||||
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
|
||||
|
||||
Args:
|
||||
flow_path: Filename to search for (default ``main.py``).
|
||||
|
||||
Returns:
|
||||
A list of discovered Flow instances.
|
||||
"""
|
||||
flow_instances: list[Flow] = []
|
||||
try:
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
src_dir = os.path.join(current_dir, "src")
|
||||
if os.path.isdir(src_dir) and src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
|
||||
|
||||
for search_path in search_paths:
|
||||
for root, dirs, files in os.walk(search_path):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
]
|
||||
if flow_path in files and "cli/templates" not in root:
|
||||
file_os_path = os.path.join(root, flow_path)
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"flow_module", file_os_path
|
||||
)
|
||||
if not spec or not spec.loader:
|
||||
continue
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for attr_name in dir(module):
|
||||
module_attr = getattr(module, attr_name)
|
||||
try:
|
||||
if flow_instance := get_flow_instance(
|
||||
module_attr
|
||||
):
|
||||
flow_instances.append(flow_instance)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if flow_instances:
|
||||
break
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
continue
|
||||
|
||||
if flow_instances:
|
||||
break
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return flow_instances
|
||||
|
||||
|
||||
def is_valid_tool(obj: Any) -> bool:
|
||||
from crewai.tools.base_tool import Tool
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ def mock_crew():
|
||||
def mock_get_crews(mock_crew):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew:
|
||||
) as mock_get_crew, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
):
|
||||
yield mock_get_crew
|
||||
|
||||
|
||||
@@ -193,6 +195,79 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
|
||||
assert call_count == 2, "reset_memories should have been called twice"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
_mock = mock.Mock()
|
||||
_mock.name = "TestFlow"
|
||||
_mock.memory = mock.Mock()
|
||||
_mock.memory.reset = mock.Mock()
|
||||
return _mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_flows(mock_flow):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
) as mock_get_flow, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
):
|
||||
yield mock_get_flow
|
||||
|
||||
|
||||
def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert "[Flow (TestFlow)] Reset memories command has been completed." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
mock_flow.memory.reset.assert_not_called()
|
||||
assert "[Flow (TestFlow)]" not in result.output
|
||||
|
||||
|
||||
def test_reset_no_crew_or_flow_found(runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_memory_none(runner):
|
||||
mock_flow = mock.Mock()
|
||||
mock_flow.name = "NoMemFlow"
|
||||
mock_flow.memory = None
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
|
||||
Reference in New Issue
Block a user