mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-09 07:28:15 +00:00
Compare commits
13 Commits
main
...
joaomdmour
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01a136c9e5 | ||
|
|
1f2a6c26e3 | ||
|
|
fd1763a447 | ||
|
|
0c9be461a6 | ||
|
|
8ae5baea48 | ||
|
|
c119d9aa34 | ||
|
|
d536674a1f | ||
|
|
00fb9502ce | ||
|
|
92aefd4502 | ||
|
|
797dbe2123 | ||
|
|
41f080240a | ||
|
|
9f9ce7ceb1 | ||
|
|
a060e7f442 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -27,3 +27,5 @@ conceptual_plan.md
|
||||
build_image
|
||||
chromadb-*.lock
|
||||
.claude
|
||||
.crewai/memory
|
||||
blogs/*
|
||||
|
||||
@@ -11,7 +11,11 @@ from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
import pytest
|
||||
from vcr.request import Request # type: ignore[import-untyped]
|
||||
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
|
||||
|
||||
try:
|
||||
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
|
||||
except ModuleNotFoundError:
|
||||
import vcr.stubs.httpcore_stubs as httpx_stubs # type: ignore[import-untyped]
|
||||
|
||||
|
||||
env_test_path = Path(__file__).parent / ".env.test"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -26,6 +26,8 @@ dependencies = [
|
||||
# Authentication and Security
|
||||
"python-dotenv~=1.1.1",
|
||||
"pyjwt>=2.9.0,<3",
|
||||
# TUI
|
||||
"textual>=7.5.0",
|
||||
# Configuration and Utils
|
||||
"click~=8.1.7",
|
||||
"appdirs~=1.4.4",
|
||||
@@ -39,6 +41,7 @@ dependencies = [
|
||||
"mcp~=1.23.1",
|
||||
"uv~=0.9.13",
|
||||
"aiosqlite~=0.21.0",
|
||||
"lancedb>=0.4.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -10,6 +10,7 @@ from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
@@ -80,6 +81,7 @@ __all__ = [
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
"LLMGuardrail",
|
||||
"Memory",
|
||||
"Process",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
|
||||
@@ -71,7 +71,6 @@ from crewai.mcp import (
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
@@ -311,19 +310,12 @@ class Agent(BaseAgent):
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
|
||||
|
||||
def _is_any_available_memory(self) -> bool:
|
||||
"""Check if any memory is available."""
|
||||
if not self.crew:
|
||||
return False
|
||||
|
||||
memory_attributes = [
|
||||
"memory",
|
||||
"_short_term_memory",
|
||||
"_long_term_memory",
|
||||
"_entity_memory",
|
||||
"_external_memory",
|
||||
]
|
||||
|
||||
return any(getattr(self.crew, attr) for attr in memory_attributes)
|
||||
"""Check if unified memory is available (agent or crew)."""
|
||||
if getattr(self, "memory", None):
|
||||
return True
|
||||
if self.crew and getattr(self.crew, "_memory", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _supports_native_tool_calling(self, tools: list[BaseTool]) -> bool:
|
||||
"""Check if the LLM supports native function calling with the given tools.
|
||||
@@ -387,15 +379,16 @@ class Agent(BaseAgent):
|
||||
memory = ""
|
||||
|
||||
try:
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
unified_memory = getattr(self, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context or "")
|
||||
if unified_memory is not None:
|
||||
query = task.description
|
||||
matches = unified_memory.recall(query, limit=10, depth="shallow")
|
||||
if matches:
|
||||
memory = "Relevant memories:\n" + "\n".join(
|
||||
f"- {m.record.content}" for m in matches
|
||||
)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
@@ -624,17 +617,16 @@ class Agent(BaseAgent):
|
||||
memory = ""
|
||||
|
||||
try:
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = await contextual_memory.abuild_context_for_task(
|
||||
task, context or ""
|
||||
unified_memory = getattr(self, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
if unified_memory is not None:
|
||||
query = task.description
|
||||
matches = unified_memory.recall(query, limit=10, depth="shallow")
|
||||
if matches:
|
||||
memory = "Relevant memories:\n" + "\n".join(
|
||||
f"- {m.record.content}" for m in matches
|
||||
)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
|
||||
@@ -199,6 +199,10 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: Any = Field(
|
||||
default=None,
|
||||
description="Memory instance for this agent (Memory, MemoryScope, or MemorySlice). If not set, falls back to crew memory.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
@@ -30,110 +25,29 @@ class CrewAgentExecutorMixin:
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
def _create_short_term_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save a short-term memory item if conditions are met."""
|
||||
def _save_to_memory(self, output: AgentFinish) -> None:
|
||||
"""Save task result to unified memory (memory or crew._memory)."""
|
||||
memory = getattr(self.agent, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
if memory is None or not self.task:
|
||||
return
|
||||
if (
|
||||
self.crew
|
||||
and self.agent
|
||||
and self.task
|
||||
and f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
not in output.text
|
||||
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
in output.text
|
||||
):
|
||||
try:
|
||||
if (
|
||||
hasattr(self.crew, "_short_term_memory")
|
||||
and self.crew._short_term_memory
|
||||
):
|
||||
self.crew._short_term_memory.save(
|
||||
value=output.text,
|
||||
metadata={
|
||||
"observation": self.task.description,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to add to short term memory: {e}"
|
||||
)
|
||||
|
||||
def _create_external_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save a external-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
and self.agent
|
||||
and self.task
|
||||
and hasattr(self.crew, "_external_memory")
|
||||
and self.crew._external_memory
|
||||
):
|
||||
try:
|
||||
self.crew._external_memory.save(
|
||||
value=output.text,
|
||||
metadata={
|
||||
"description": self.task.description,
|
||||
"messages": self.messages,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to add to external memory: {e}"
|
||||
)
|
||||
|
||||
def _create_long_term_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
if (
|
||||
self.crew
|
||||
and self.crew._long_term_memory
|
||||
and self.crew._entity_memory
|
||||
and self.task
|
||||
and self.agent
|
||||
):
|
||||
try:
|
||||
ltm_agent = TaskEvaluator(self.agent)
|
||||
evaluation = ltm_agent.evaluate(self.task, output.text)
|
||||
|
||||
if isinstance(evaluation, ConverterError):
|
||||
return
|
||||
|
||||
long_term_memory = LongTermMemoryItem(
|
||||
task=self.task.description,
|
||||
agent=self.agent.role,
|
||||
quality=evaluation.quality,
|
||||
datetime=str(time.time()),
|
||||
expected_output=self.task.expected_output,
|
||||
metadata={
|
||||
"suggestions": evaluation.suggestions,
|
||||
"quality": evaluation.quality,
|
||||
},
|
||||
)
|
||||
self.crew._long_term_memory.save(long_term_memory)
|
||||
|
||||
entity_memories = [
|
||||
EntityMemoryItem(
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
relationships="\n".join(
|
||||
[f"- {r}" for r in entity.relationships]
|
||||
),
|
||||
)
|
||||
for entity in evaluation.entities
|
||||
]
|
||||
if entity_memories:
|
||||
self.crew._entity_memory.save(entity_memories)
|
||||
except AttributeError as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Missing attributes for long term memory: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to add to long term memory: {e}"
|
||||
)
|
||||
elif (
|
||||
self.crew
|
||||
and self.crew._long_term_memory
|
||||
and self.crew._entity_memory is None
|
||||
):
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(
|
||||
content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.",
|
||||
color="bold_yellow",
|
||||
)
|
||||
return
|
||||
try:
|
||||
raw = (
|
||||
f"Task: {self.task.description}\n"
|
||||
f"Agent: {self.agent.role}\n"
|
||||
f"Expected result: {self.task.expected_output}\n"
|
||||
f"Result: {output.text}"
|
||||
)
|
||||
extracted = memory.extract_memories(raw)
|
||||
for mem in extracted:
|
||||
memory.remember(mem)
|
||||
except Exception as e:
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to save to memory: {e}"
|
||||
)
|
||||
|
||||
@@ -234,9 +234,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
self._save_to_memory(formatted_answer)
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
def _inject_multimodal_files(self, inputs: dict[str, Any] | None = None) -> None:
|
||||
@@ -1011,9 +1009,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = await self._ahandle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
self._save_to_memory(formatted_answer)
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
async def _ainvoke_loop(self) -> AgentFinish:
|
||||
|
||||
@@ -179,9 +179,19 @@ def log_tasks_outputs() -> None:
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory")
|
||||
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
|
||||
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
|
||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||
@click.option(
|
||||
"-l", "--long", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-s", "--short", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option(
|
||||
"-e", "--entities", is_flag=True, hidden=True,
|
||||
help="[Deprecated: use --memory] Reset memory",
|
||||
)
|
||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||
@click.option(
|
||||
"-akn", "--agent-knowledge", is_flag=True, help="Reset AGENT KNOWLEDGE storage"
|
||||
@@ -191,6 +201,7 @@ def log_tasks_outputs() -> None:
|
||||
)
|
||||
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
|
||||
def reset_memories(
|
||||
memory: bool,
|
||||
long: bool,
|
||||
short: bool,
|
||||
entities: bool,
|
||||
@@ -200,13 +211,22 @@ def reset_memories(
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs, knowledge, agent_knowledge). This will delete all the data saved.
|
||||
Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved.
|
||||
"""
|
||||
try:
|
||||
# Treat legacy flags as --memory with a deprecation warning
|
||||
if long or short or entities:
|
||||
legacy_used = [
|
||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
||||
]
|
||||
click.echo(
|
||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||
"deprecated. Use --memory (-m) instead. All memory is now unified."
|
||||
)
|
||||
memory = True
|
||||
|
||||
memory_types = [
|
||||
long,
|
||||
short,
|
||||
entities,
|
||||
memory,
|
||||
knowledge,
|
||||
agent_knowledge,
|
||||
kickoff_outputs,
|
||||
@@ -218,12 +238,33 @@ def reset_memories(
|
||||
)
|
||||
return
|
||||
reset_memories_command(
|
||||
long, short, entities, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.option(
|
||||
"--storage-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LanceDB memory directory. If omitted, uses ./.crewai/memory.",
|
||||
)
|
||||
def memory(storage_path: str | None) -> None:
|
||||
"""Open the Memory TUI to browse scopes and recall memories."""
|
||||
try:
|
||||
from crewai.cli.memory_tui import MemoryTUI
|
||||
except ImportError as exc:
|
||||
click.echo(
|
||||
"Textual is required for the memory TUI but could not be imported. "
|
||||
"Try reinstalling crewai or: pip install textual"
|
||||
)
|
||||
raise SystemExit(1) from exc
|
||||
app = MemoryTUI(storage_path=storage_path)
|
||||
app.run()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.option(
|
||||
"-n",
|
||||
|
||||
257
lib/crewai/src/crewai/cli/memory_tui.py
Normal file
257
lib/crewai/src/crewai/cli/memory_tui.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Textual TUI for browsing and recalling unified memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import Horizontal
|
||||
from textual.widgets import Footer, Header, Input, Static, Tree
|
||||
|
||||
|
||||
# -- CrewAI brand palette --
|
||||
_PRIMARY = "#eb6658" # coral
|
||||
_SECONDARY = "#1F7982" # teal
|
||||
_TERTIARY = "#ffffff" # white
|
||||
|
||||
|
||||
def _format_scope_info(info: Any) -> str:
|
||||
"""Format ScopeInfo with Rich markup."""
|
||||
return (
|
||||
f"[bold {_PRIMARY}]{info.path}[/]\n\n"
|
||||
f"[dim]Records:[/] [bold]{info.record_count}[/]\n"
|
||||
f"[dim]Categories:[/] {', '.join(info.categories) or 'none'}\n"
|
||||
f"[dim]Oldest:[/] {info.oldest_record or '-'}\n"
|
||||
f"[dim]Newest:[/] {info.newest_record or '-'}\n"
|
||||
f"[dim]Children:[/] {', '.join(info.child_scopes) or 'none'}"
|
||||
)
|
||||
|
||||
|
||||
class MemoryTUI(App[None]):
|
||||
"""TUI to browse memory scopes and run recall queries."""
|
||||
|
||||
TITLE = "CrewAI Memory"
|
||||
SUB_TITLE = "Browse scopes and recall memories"
|
||||
|
||||
PAGE_SIZE: ClassVar[int] = 20
|
||||
|
||||
BINDINGS: ClassVar[list[tuple[str, str, str]]] = [
|
||||
("n", "next_page", "Next page"),
|
||||
("p", "prev_page", "Prev page"),
|
||||
]
|
||||
|
||||
CSS = f"""
|
||||
Header {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer > .footer-key--key {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Horizontal {{
|
||||
height: 1fr;
|
||||
}}
|
||||
#scope-tree {{
|
||||
width: 30%;
|
||||
padding: 1 2;
|
||||
background: {_SECONDARY} 8%;
|
||||
border-right: solid {_SECONDARY};
|
||||
}}
|
||||
#scope-tree:focus > .tree--cursor {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#scope-tree > .tree--guides {{
|
||||
color: {_SECONDARY} 50%;
|
||||
}}
|
||||
#scope-tree > .tree--guides-hover {{
|
||||
color: {_PRIMARY};
|
||||
}}
|
||||
#scope-tree > .tree--guides-selected {{
|
||||
color: {_SECONDARY};
|
||||
}}
|
||||
#info-panel {{
|
||||
width: 70%;
|
||||
padding: 1 2;
|
||||
overflow-y: auto;
|
||||
}}
|
||||
#info-panel LoadingIndicator {{
|
||||
color: {_PRIMARY};
|
||||
}}
|
||||
#recall-input {{
|
||||
margin: 0 1 1 1;
|
||||
border: tall {_SECONDARY};
|
||||
}}
|
||||
#recall-input:focus {{
|
||||
border: tall {_PRIMARY};
|
||||
}}
|
||||
"""
|
||||
|
||||
def __init__(self, storage_path: str | None = None) -> None:
|
||||
super().__init__()
|
||||
self._memory: Any = None
|
||||
self._init_error: str | None = None
|
||||
self._selected_scope: str = "/"
|
||||
self._entries: list[Any] = []
|
||||
self._page: int = 0
|
||||
self._last_scope_info: Any = None
|
||||
try:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||
self._memory = Memory(storage=storage)
|
||||
except Exception as e:
|
||||
self._init_error = str(e)
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=False)
|
||||
with Horizontal():
|
||||
yield self._build_scope_tree()
|
||||
initial = (
|
||||
self._init_error
|
||||
if self._init_error
|
||||
else "Select a scope or type a recall query."
|
||||
)
|
||||
yield Static(initial, id="info-panel")
|
||||
yield Input(
|
||||
placeholder="Type a query and press Enter to recall...",
|
||||
id="recall-input",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def _build_scope_tree(self) -> Tree[str]:
|
||||
tree: Tree[str] = Tree("/", id="scope-tree")
|
||||
if self._memory is None:
|
||||
tree.root.data = "/"
|
||||
tree.root.label = "/ (0 records)"
|
||||
return tree
|
||||
info = self._memory.info("/")
|
||||
tree.root.label = f"/ ({info.record_count} records)"
|
||||
tree.root.data = "/"
|
||||
self._add_children(tree.root, "/", depth=0, max_depth=3)
|
||||
tree.root.expand()
|
||||
return tree
|
||||
|
||||
def _add_children(
|
||||
self,
|
||||
parent_node: Tree.Node[str],
|
||||
path: str,
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
) -> None:
|
||||
if depth >= max_depth or self._memory is None:
|
||||
return
|
||||
info = self._memory.info(path)
|
||||
for child in info.child_scopes:
|
||||
child_info = self._memory.info(child)
|
||||
label = f"{child} ({child_info.record_count})"
|
||||
node = parent_node.add(label, data=child)
|
||||
self._add_children(node, child, depth + 1, max_depth)
|
||||
|
||||
def on_tree_node_selected(self, event: Tree.NodeSelected[str]) -> None:
|
||||
path = event.node.data if event.node.data is not None else "/"
|
||||
self._selected_scope = path
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
if self._memory is None:
|
||||
panel.update(self._init_error or "No memory loaded.")
|
||||
return
|
||||
info = self._memory.info(path)
|
||||
self._last_scope_info = info
|
||||
self._entries = self._memory.list_records(scope=path, limit=200)
|
||||
self._page = 0
|
||||
self._render_panel()
|
||||
|
||||
def _render_panel(self) -> None:
|
||||
"""Refresh info panel with scope info and current page of entries."""
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
if self._last_scope_info is None:
|
||||
return
|
||||
lines: list[str] = [_format_scope_info(self._last_scope_info)]
|
||||
total = len(self._entries)
|
||||
if total == 0:
|
||||
lines.append("\n[dim]No entries in this scope.[/]")
|
||||
else:
|
||||
total_pages = (total + self.PAGE_SIZE - 1) // self.PAGE_SIZE
|
||||
page_num = self._page + 1
|
||||
lines.append(
|
||||
f"\n[bold {_PRIMARY}]{'─' * 44}[/]\n"
|
||||
f"[bold]Entries[/] [dim](page {page_num} of {total_pages})[/]\n"
|
||||
)
|
||||
start = self._page * self.PAGE_SIZE
|
||||
end = min(start + self.PAGE_SIZE, total)
|
||||
for record in self._entries[start:end]:
|
||||
date_str = record.created_at.strftime("%Y-%m-%d")
|
||||
preview = (record.content[:100] + "…") if len(record.content) > 100 else record.content
|
||||
lines.append(
|
||||
f"[{_SECONDARY}]{date_str}[/] "
|
||||
f"[bold]{record.importance:.1f}[/] "
|
||||
f"[dim]{preview}[/]"
|
||||
)
|
||||
if total_pages > 1:
|
||||
lines.append("\n[dim]\\[n] next \\[p] prev[/]")
|
||||
panel.update("\n".join(lines))
|
||||
|
||||
def action_next_page(self) -> None:
|
||||
"""Go to next page of entries."""
|
||||
total_pages = (len(self._entries) + self.PAGE_SIZE - 1) // self.PAGE_SIZE
|
||||
if total_pages <= 0:
|
||||
return
|
||||
self._page = min(self._page + 1, total_pages - 1)
|
||||
self._render_panel()
|
||||
|
||||
def action_prev_page(self) -> None:
|
||||
"""Go to previous page of entries."""
|
||||
if self._page <= 0:
|
||||
return
|
||||
self._page -= 1
|
||||
self._render_panel()
|
||||
|
||||
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
query = event.value.strip()
|
||||
if not query:
|
||||
return
|
||||
if self._memory is None:
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.update(self._init_error or "No memory loaded. Cannot recall.")
|
||||
return
|
||||
self.run_worker(self._do_recall(query), exclusive=True)
|
||||
|
||||
async def _do_recall(self, query: str) -> None:
|
||||
panel = self.query_one("#info-panel", Static)
|
||||
panel.loading = True
|
||||
try:
|
||||
scope = (
|
||||
self._selected_scope
|
||||
if self._selected_scope != "/"
|
||||
else None
|
||||
)
|
||||
matches = self._memory.recall(
|
||||
query,
|
||||
scope=scope,
|
||||
limit=10,
|
||||
depth="shallow",
|
||||
)
|
||||
if not matches:
|
||||
panel.update("[dim]No memories found.[/]")
|
||||
return
|
||||
lines: list[str] = []
|
||||
for m in matches:
|
||||
content = m.record.content
|
||||
score_color = _PRIMARY if m.score >= 0.5 else "dim"
|
||||
lines.append(f"[{score_color}]\\[{m.score:.2f}][/] {content[:120]}")
|
||||
lines.append(
|
||||
f" [dim]scope={m.record.scope} "
|
||||
f"importance={m.record.importance:.1f}[/]"
|
||||
)
|
||||
lines.append("")
|
||||
panel.update("\n".join(lines))
|
||||
except Exception as e:
|
||||
panel.update(f"[bold red]Error:[/] {e}")
|
||||
finally:
|
||||
panel.loading = False
|
||||
@@ -6,31 +6,23 @@ from crewai.cli.utils import get_crews
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
long,
|
||||
short,
|
||||
entity,
|
||||
knowledge,
|
||||
agent_knowledge,
|
||||
kickoff_outputs,
|
||||
all,
|
||||
memory: bool,
|
||||
knowledge: bool,
|
||||
agent_knowledge: bool,
|
||||
kickoff_outputs: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the crew memories.
|
||||
"""Reset the crew memories.
|
||||
|
||||
Args:
|
||||
long (bool): Whether to reset the long-term memory.
|
||||
short (bool): Whether to reset the short-term memory.
|
||||
entity (bool): Whether to reset the entity memory.
|
||||
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
|
||||
all (bool): Whether to reset all memories.
|
||||
knowledge (bool): Whether to reset the knowledge.
|
||||
agent_knowledge (bool): Whether to reset the agents knowledge.
|
||||
memory: Whether to reset the unified memory.
|
||||
knowledge: Whether to reset the knowledge.
|
||||
agent_knowledge: Whether to reset the agents knowledge.
|
||||
kickoff_outputs: Whether to reset the latest kickoff task outputs.
|
||||
all: Whether to reset all memories.
|
||||
"""
|
||||
|
||||
try:
|
||||
if not any(
|
||||
[long, short, entity, kickoff_outputs, knowledge, agent_knowledge, all]
|
||||
):
|
||||
if not any([memory, kickoff_outputs, knowledge, agent_knowledge, all]):
|
||||
click.echo(
|
||||
"No memory type specified. Please specify at least one type to reset."
|
||||
)
|
||||
@@ -46,20 +38,10 @@ def reset_memories_command(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
|
||||
)
|
||||
continue
|
||||
if long:
|
||||
crew.reset_memories(command_type="long")
|
||||
if memory:
|
||||
crew.reset_memories(command_type="memory")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
|
||||
)
|
||||
if short:
|
||||
crew.reset_memories(command_type="short")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
|
||||
)
|
||||
if entity:
|
||||
crew.reset_memories(command_type="entity")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Memory has been reset."
|
||||
)
|
||||
if kickoff_outputs:
|
||||
crew.reset_memories(command_type="kickoff_outputs")
|
||||
|
||||
@@ -83,10 +83,6 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
@@ -174,10 +170,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr()
|
||||
_long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr()
|
||||
_entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr()
|
||||
_external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr()
|
||||
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
@@ -194,25 +187,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
memory: bool | Any = Field(
|
||||
default=False,
|
||||
description="If crew should use memory to store memories of it's execution",
|
||||
)
|
||||
short_term_memory: InstanceOf[ShortTermMemory] | None = Field(
|
||||
default=None,
|
||||
description="An Instance of the ShortTermMemory to be used by the Crew",
|
||||
)
|
||||
long_term_memory: InstanceOf[LongTermMemory] | None = Field(
|
||||
default=None,
|
||||
description="An Instance of the LongTermMemory to be used by the Crew",
|
||||
)
|
||||
entity_memory: InstanceOf[EntityMemory] | None = Field(
|
||||
default=None,
|
||||
description="An Instance of the EntityMemory to be used by the Crew",
|
||||
)
|
||||
external_memory: InstanceOf[ExternalMemory] | None = Field(
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
"or a Memory/MemoryScope/MemorySlice instance for custom configuration."
|
||||
),
|
||||
)
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
default=None,
|
||||
@@ -371,31 +351,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
embedder_config=self.embedder,
|
||||
)
|
||||
self._entity_memory = self.entity_memory or EntityMemory(
|
||||
crew=self, embedder_config=self.embedder
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> Crew:
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory does not support a default value since it was
|
||||
# designed to be managed entirely externally
|
||||
self.external_memory.set_crew(self) if self.external_memory else None
|
||||
)
|
||||
"""Initialize unified memory, respecting crew embedder config."""
|
||||
if self.memory is True:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
self._long_term_memory = self.long_term_memory
|
||||
self._short_term_memory = self.short_term_memory
|
||||
self._entity_memory = self.entity_memory
|
||||
embedder = None
|
||||
if self.embedder is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
if self.memory:
|
||||
self._initialize_default_memories()
|
||||
embedder = build_embedder(self.embedder)
|
||||
self._memory = Memory(embedder=embedder)
|
||||
elif self.memory:
|
||||
# User passed a Memory / MemoryScope / MemorySlice instance
|
||||
self._memory = self.memory
|
||||
else:
|
||||
self._memory = None
|
||||
|
||||
return self
|
||||
|
||||
@@ -1664,10 +1636,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"_execution_span",
|
||||
"_file_handler",
|
||||
"_cache_handler",
|
||||
"_short_term_memory",
|
||||
"_long_term_memory",
|
||||
"_entity_memory",
|
||||
"_external_memory",
|
||||
"_memory",
|
||||
"agents",
|
||||
"tasks",
|
||||
"knowledge_sources",
|
||||
@@ -1701,18 +1670,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
if self.short_term_memory:
|
||||
copied_data["short_term_memory"] = self.short_term_memory.model_copy(
|
||||
deep=True
|
||||
)
|
||||
if self.long_term_memory:
|
||||
copied_data["long_term_memory"] = self.long_term_memory.model_copy(
|
||||
deep=True
|
||||
)
|
||||
if self.entity_memory:
|
||||
copied_data["entity_memory"] = self.entity_memory.model_copy(deep=True)
|
||||
if self.external_memory:
|
||||
copied_data["external_memory"] = self.external_memory.model_copy(deep=True)
|
||||
if getattr(self, "_memory", None):
|
||||
copied_data["memory"] = self._memory
|
||||
|
||||
copied_data.pop("agents", None)
|
||||
copied_data.pop("tasks", None)
|
||||
@@ -1843,23 +1802,24 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
Args:
|
||||
command_type: Type of memory to reset.
|
||||
Valid options: 'long', 'short', 'entity', 'knowledge', 'agent_knowledge'
|
||||
'kickoff_outputs', or 'all'
|
||||
Valid options: 'memory', 'knowledge', 'agent_knowledge',
|
||||
'kickoff_outputs', or 'all'. Legacy names 'long', 'short',
|
||||
'entity', 'external' are treated as 'memory'.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid command type is provided.
|
||||
RuntimeError: If memory reset operation fails.
|
||||
"""
|
||||
legacy_memory = frozenset(["long", "short", "entity", "external"])
|
||||
if command_type in legacy_memory:
|
||||
command_type = "memory"
|
||||
valid_types = frozenset(
|
||||
[
|
||||
"long",
|
||||
"short",
|
||||
"entity",
|
||||
"memory",
|
||||
"knowledge",
|
||||
"agent_knowledge",
|
||||
"kickoff_outputs",
|
||||
"all",
|
||||
"external",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1965,25 +1925,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
) + agent_knowledges
|
||||
|
||||
return {
|
||||
"short": {
|
||||
"system": getattr(self, "_short_term_memory", None),
|
||||
"memory": {
|
||||
"system": getattr(self, "_memory", None),
|
||||
"reset": default_reset,
|
||||
"name": "Short Term",
|
||||
},
|
||||
"entity": {
|
||||
"system": getattr(self, "_entity_memory", None),
|
||||
"reset": default_reset,
|
||||
"name": "Entity",
|
||||
},
|
||||
"external": {
|
||||
"system": getattr(self, "_external_memory", None),
|
||||
"reset": default_reset,
|
||||
"name": "External",
|
||||
},
|
||||
"long": {
|
||||
"system": getattr(self, "_long_term_memory", None),
|
||||
"reset": default_reset,
|
||||
"name": "Long Term",
|
||||
"name": "Memory",
|
||||
},
|
||||
"kickoff_outputs": {
|
||||
"system": getattr(self, "_task_output_handler", None),
|
||||
|
||||
@@ -1106,9 +1106,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
if self.state.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
self._save_to_memory(formatted_answer)
|
||||
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
@@ -1191,9 +1189,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
if self.state.ask_for_human_input:
|
||||
formatted_answer = await self._ahandle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
self._save_to_memory(formatted_answer)
|
||||
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
|
||||
@@ -416,13 +416,18 @@ def and_(*conditions: str | FlowCondition | Callable[..., Any]) -> FlowCondition
|
||||
return {"type": AND_CONDITION, "conditions": processed_conditions}
|
||||
|
||||
|
||||
class LockedListProxy(Generic[T]):
|
||||
class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
"""Thread-safe proxy for list operations.
|
||||
|
||||
Wraps a list and uses a lock for all mutating operations.
|
||||
Subclasses ``list`` so that ``isinstance(proxy, list)`` returns True,
|
||||
which is required by libraries like LanceDB and Pydantic that do strict
|
||||
type checks. All mutations go through the lock; reads delegate to the
|
||||
underlying list.
|
||||
"""
|
||||
|
||||
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
|
||||
# Do NOT call super().__init__() -- we don't want to copy data into
|
||||
# the builtin list storage. All access goes through self._list.
|
||||
self._list = lst
|
||||
self._lock = lock
|
||||
|
||||
@@ -477,13 +482,17 @@ class LockedListProxy(Generic[T]):
|
||||
return bool(self._list)
|
||||
|
||||
|
||||
class LockedDictProxy(Generic[T]):
|
||||
class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
"""Thread-safe proxy for dict operations.
|
||||
|
||||
Wraps a dict and uses a lock for all mutating operations.
|
||||
Subclasses ``dict`` so that ``isinstance(proxy, dict)`` returns True,
|
||||
which is required by libraries like Pydantic that do strict type checks.
|
||||
All mutations go through the lock; reads delegate to the underlying dict.
|
||||
"""
|
||||
|
||||
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
|
||||
# Do NOT call super().__init__() -- we don't want to copy data into
|
||||
# the builtin dict storage. All access goes through self._dict.
|
||||
self._dict = d
|
||||
self._lock = lock
|
||||
|
||||
@@ -700,6 +709,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = None
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -767,6 +777,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
|
||||
# Auto-create memory if not provided at class or instance level
|
||||
if self.memory is None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
self.memory = Memory()
|
||||
|
||||
# Register all flow-related methods
|
||||
for method_name in dir(self):
|
||||
if not method_name.startswith("_"):
|
||||
@@ -777,6 +793,56 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method = method.__get__(self, self.__class__)
|
||||
self._methods[method.__name__] = method
|
||||
|
||||
def recall(self, query: str, **kwargs: Any) -> Any:
|
||||
"""Recall relevant memories. Delegates to this flow's memory.
|
||||
|
||||
Args:
|
||||
query: Natural language query.
|
||||
**kwargs: Passed to memory.recall (e.g. scope, categories, limit, depth).
|
||||
|
||||
Returns:
|
||||
Result of memory.recall(query, **kwargs).
|
||||
|
||||
Raises:
|
||||
ValueError: If no memory is configured for this flow.
|
||||
"""
|
||||
if self.memory is None:
|
||||
raise ValueError("No memory configured for this flow")
|
||||
return self.memory.recall(query, **kwargs)
|
||||
|
||||
def remember(self, content: str, **kwargs: Any) -> Any:
|
||||
"""Store content in memory. Delegates to this flow's memory.
|
||||
|
||||
Args:
|
||||
content: Text to remember.
|
||||
**kwargs: Passed to memory.remember (e.g. scope, categories, metadata).
|
||||
|
||||
Returns:
|
||||
Result of memory.remember(content, **kwargs).
|
||||
|
||||
Raises:
|
||||
ValueError: If no memory is configured for this flow.
|
||||
"""
|
||||
if self.memory is None:
|
||||
raise ValueError("No memory configured for this flow")
|
||||
return self.memory.remember(content, **kwargs)
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content. Delegates to this flow's memory.
|
||||
|
||||
Args:
|
||||
content: Raw text (e.g. task + result dump).
|
||||
|
||||
Returns:
|
||||
List of short, self-contained memory statements.
|
||||
|
||||
Raises:
|
||||
ValueError: If no memory is configured for this flow.
|
||||
"""
|
||||
if self.memory is None:
|
||||
raise ValueError("No memory configured for this flow")
|
||||
return self.memory.extract_memories(content)
|
||||
|
||||
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
|
||||
"""Mark an OR listener as fired atomically.
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import time
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import json
|
||||
@@ -48,6 +49,11 @@ from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
|
||||
@@ -244,6 +250,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. "
|
||||
"Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of configurations.",
|
||||
)
|
||||
memory: bool | Any | None = Field(
|
||||
default=None,
|
||||
description="If True, use default Memory(). If Memory/MemoryScope/MemorySlice, use it for recall and remember.",
|
||||
)
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Results of the tools used by the agent."
|
||||
)
|
||||
@@ -266,6 +276,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
|
||||
default_factory=get_after_llm_call_hooks
|
||||
)
|
||||
_memory: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def emit_deprecation_warning(self) -> Self:
|
||||
@@ -363,6 +374,19 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def resolve_memory(self) -> Self:
|
||||
"""Resolve memory field to _memory: default Memory() when True, else user instance or None."""
|
||||
if self.memory is True:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
object.__setattr__(self, "_memory", Memory())
|
||||
elif self.memory is not None and self.memory is not False:
|
||||
object.__setattr__(self, "_memory", self.memory)
|
||||
else:
|
||||
object.__setattr__(self, "_memory", None)
|
||||
return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
@@ -474,6 +498,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self._messages = self._format_messages(
|
||||
messages, response_format=response_format, input_files=input_files
|
||||
)
|
||||
self._inject_memory_context()
|
||||
|
||||
return self._execute_core(
|
||||
agent_info=agent_info, response_format=response_format
|
||||
@@ -496,6 +521,80 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _get_last_user_content(self) -> str:
|
||||
"""Get the last user message content from _messages for recall/input."""
|
||||
for msg in reversed(self._messages):
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content")
|
||||
return content if isinstance(content, str) else ""
|
||||
return ""
|
||||
|
||||
def _inject_memory_context(self) -> None:
|
||||
"""Recall relevant memories and append to the system message. No-op if _memory is None."""
|
||||
if self._memory is None:
|
||||
return
|
||||
query = self._get_last_user_content()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalStartedEvent(
|
||||
task_id=None,
|
||||
source_type="lite_agent",
|
||||
),
|
||||
)
|
||||
start_time = time.time()
|
||||
memory_block = ""
|
||||
try:
|
||||
matches = self._memory.recall(query, limit=10, depth="shallow")
|
||||
if matches:
|
||||
memory_block = "Relevant memories:\n" + "\n".join(
|
||||
f"- {m.record.content}" for m in matches
|
||||
)
|
||||
if memory_block:
|
||||
formatted = self.i18n.slice("memory").format(memory=memory_block)
|
||||
if self._messages and self._messages[0].get("role") == "system":
|
||||
self._messages[0]["content"] = (
|
||||
self._messages[0].get("content", "") + "\n\n" + formatted
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalCompletedEvent(
|
||||
task_id=None,
|
||||
memory_content=memory_block,
|
||||
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="lite_agent",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalFailedEvent(
|
||||
task_id=None,
|
||||
source_type="lite_agent",
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
|
||||
def _save_to_memory(self, output_text: str) -> None:
|
||||
"""Extract discrete memories from the run and remember each. No-op if _memory is None."""
|
||||
if self._memory is None:
|
||||
return
|
||||
input_str = self._get_last_user_content() or "User request"
|
||||
try:
|
||||
raw = (
|
||||
f"Input: {input_str}\n"
|
||||
f"Agent: {self.role}\n"
|
||||
f"Result: {output_text}"
|
||||
)
|
||||
extracted = self._memory.extract_memories(raw)
|
||||
for mem in extracted:
|
||||
self._memory.remember(mem)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
self._printer.print(
|
||||
content=f"Failed to save to memory: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
def _execute_core(
|
||||
self, agent_info: dict[str, Any], response_format: type[BaseModel] | None = None
|
||||
) -> LiteAgentOutput:
|
||||
@@ -511,6 +610,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
agent_finish = self._invoke_loop()
|
||||
if self._memory is not None:
|
||||
self._save_to_memory(agent_finish.output)
|
||||
formatted_result: BaseModel | None = None
|
||||
|
||||
active_response_format = response_format or self.response_format
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
|
||||
"""Memory module: unified Memory with LLM analysis and pluggable storage."""
|
||||
|
||||
from crewai.memory.consolidation_flow import ConsolidationFlow
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.memory.types import (
|
||||
MemoryMatch,
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
compute_composite_score,
|
||||
embed_text,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EntityMemory",
|
||||
"ExternalMemory",
|
||||
"LongTermMemory",
|
||||
"ShortTermMemory",
|
||||
"ConsolidationFlow",
|
||||
"Memory",
|
||||
"MemoryMatch",
|
||||
"MemoryRecord",
|
||||
"MemoryScope",
|
||||
"MemorySlice",
|
||||
"ScopeInfo",
|
||||
"compute_composite_score",
|
||||
"embed_text",
|
||||
]
|
||||
|
||||
378
lib/crewai/src/crewai/memory/analyze.py
Normal file
378
lib/crewai/src/crewai/memory/analyze.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""LLM-powered analysis for memory save and recall."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractedMetadata(BaseModel):
|
||||
"""Fixed schema for LLM-extracted metadata (OpenAI requires additionalProperties: false)."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
entities: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Entities (people, orgs, places) mentioned in the content.",
|
||||
)
|
||||
dates: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Dates or time references in the content.",
|
||||
)
|
||||
topics: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Topics or themes in the content.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryAnalysis(BaseModel):
|
||||
"""LLM output for analyzing content before saving to memory."""
|
||||
|
||||
suggested_scope: str = Field(
|
||||
description="Best matching existing scope or new path (e.g. /company/decisions).",
|
||||
)
|
||||
categories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Categories for the memory (prefer existing, add new if needed).",
|
||||
)
|
||||
importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score from 0.0 to 1.0.",
|
||||
)
|
||||
extracted_metadata: ExtractedMetadata = Field(
|
||||
default_factory=ExtractedMetadata,
|
||||
description="Entities, dates, topics extracted from the content.",
|
||||
)
|
||||
|
||||
|
||||
class QueryAnalysis(BaseModel):
|
||||
"""LLM output for analyzing a recall query."""
|
||||
|
||||
keywords: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Key entities or keywords for filtering.",
|
||||
)
|
||||
time_hints: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Any time or recency hints in the query.",
|
||||
)
|
||||
suggested_scopes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Scope paths to search (subset of available scopes).",
|
||||
)
|
||||
complexity: str = Field(
|
||||
default="simple",
|
||||
description="One of 'simple' (single fact) or 'complex' (aggregation/reasoning).",
|
||||
)
|
||||
|
||||
|
||||
class ExtractedMemories(BaseModel):
|
||||
"""LLM output for extracting discrete memories from raw content."""
|
||||
|
||||
memories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of discrete, self-contained memory statements extracted from the content.",
|
||||
)
|
||||
|
||||
|
||||
class ConsolidationAction(BaseModel):
|
||||
"""A single action in a consolidation plan."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: str = Field(
|
||||
description="One of 'keep', 'update', or 'delete'.",
|
||||
)
|
||||
record_id: str = Field(
|
||||
description="ID of the existing record this action applies to.",
|
||||
)
|
||||
new_content: str | None = Field(
|
||||
default=None,
|
||||
description="Updated content text. Required when action is 'update'.",
|
||||
)
|
||||
reason: str = Field(
|
||||
default="",
|
||||
description="Brief reason for this action.",
|
||||
)
|
||||
|
||||
|
||||
class ConsolidationPlan(BaseModel):
|
||||
"""LLM output for consolidating new content with existing memories."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
actions: list[ConsolidationAction] = Field(
|
||||
default_factory=list,
|
||||
description="Actions to take on existing records (keep/update/delete).",
|
||||
)
|
||||
insert_new: bool = Field(
|
||||
default=True,
|
||||
description="Whether to also insert the new content as a separate record.",
|
||||
)
|
||||
insert_reason: str = Field(
|
||||
default="",
|
||||
description="Why the new content should or should not be inserted.",
|
||||
)
|
||||
|
||||
|
||||
def _get_prompt(key: str) -> str:
|
||||
"""Retrieve a memory prompt from the i18n translations.
|
||||
|
||||
Args:
|
||||
key: The prompt key under the "memory" section.
|
||||
|
||||
Returns:
|
||||
The prompt string.
|
||||
"""
|
||||
return get_i18n().memory(key)
|
||||
|
||||
|
||||
def extract_memories_from_content(content: str, llm: Any) -> list[str]:
|
||||
"""Use the LLM to extract discrete memory statements from raw content.
|
||||
|
||||
This is a pure helper: it does NOT store anything. Callers should call
|
||||
memory.remember() on each returned string to persist them.
|
||||
|
||||
On LLM failure, returns the full content as a single memory so callers
|
||||
still persist something rather than dropping the output.
|
||||
|
||||
Args:
|
||||
content: Raw text (e.g. task description + result dump).
|
||||
llm: The LLM instance to use.
|
||||
|
||||
Returns:
|
||||
List of short, self-contained memory statements (or [content] on failure).
|
||||
"""
|
||||
if not (content or "").strip():
|
||||
return []
|
||||
user = _get_prompt("extract_memories_user").format(content=content)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_prompt("extract_memories_system")},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
if getattr(llm, "supports_function_calling", lambda: False)():
|
||||
response = llm.call(messages, response_model=ExtractedMemories)
|
||||
if isinstance(response, ExtractedMemories):
|
||||
return response.memories
|
||||
return ExtractedMemories.model_validate(response).memories
|
||||
response = llm.call(messages)
|
||||
if isinstance(response, ExtractedMemories):
|
||||
return response.memories
|
||||
if isinstance(response, str):
|
||||
data = json.loads(response)
|
||||
return ExtractedMemories.model_validate(data).memories
|
||||
return ExtractedMemories.model_validate(response).memories
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Memory extraction failed, storing full content as single memory: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return [content]
|
||||
|
||||
|
||||
async def aextract_memories_from_content(content: str, llm: Any) -> list[str]:
|
||||
"""Async variant of extract_memories_from_content."""
|
||||
return extract_memories_from_content(content, llm)
|
||||
|
||||
|
||||
def analyze_for_save(
|
||||
content: str,
|
||||
existing_scopes: list[str],
|
||||
existing_categories: list[str],
|
||||
llm: Any,
|
||||
) -> MemoryAnalysis:
|
||||
"""Use the LLM to infer scope, categories, importance, and metadata for a memory.
|
||||
|
||||
On LLM failure, returns safe defaults so remember() still persists the content.
|
||||
|
||||
Args:
|
||||
content: The memory content to analyze.
|
||||
existing_scopes: Current scope paths in the memory store.
|
||||
existing_categories: Current categories in use.
|
||||
llm: The LLM instance to use.
|
||||
|
||||
Returns:
|
||||
MemoryAnalysis with suggested_scope, categories, importance, extracted_metadata.
|
||||
"""
|
||||
user = _get_prompt("save_user").format(
|
||||
content=content,
|
||||
existing_scopes=existing_scopes or ["/"],
|
||||
existing_categories=existing_categories or [],
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_prompt("save_system")},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
if getattr(llm, "supports_function_calling", lambda: False)():
|
||||
response = llm.call(messages, response_model=MemoryAnalysis)
|
||||
if isinstance(response, MemoryAnalysis):
|
||||
return response
|
||||
return MemoryAnalysis.model_validate(response)
|
||||
response = llm.call(messages)
|
||||
if isinstance(response, MemoryAnalysis):
|
||||
return response
|
||||
if isinstance(response, str):
|
||||
data = json.loads(response)
|
||||
return MemoryAnalysis.model_validate(data)
|
||||
return MemoryAnalysis.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Memory save analysis failed, using defaults (scope=/, importance=0.5): %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return MemoryAnalysis(
|
||||
suggested_scope="/",
|
||||
categories=[],
|
||||
importance=0.5,
|
||||
extracted_metadata=ExtractedMetadata(),
|
||||
)
|
||||
|
||||
|
||||
async def aanalyze_for_save(
|
||||
content: str,
|
||||
existing_scopes: list[str],
|
||||
existing_categories: list[str],
|
||||
llm: Any,
|
||||
) -> MemoryAnalysis:
|
||||
"""Async variant of analyze_for_save."""
|
||||
# Fallback to sync if no acall
|
||||
return analyze_for_save(content, existing_scopes, existing_categories, llm)
|
||||
|
||||
|
||||
def analyze_query(
|
||||
query: str,
|
||||
available_scopes: list[str],
|
||||
scope_info: ScopeInfo | None,
|
||||
llm: Any,
|
||||
) -> QueryAnalysis:
|
||||
"""Use the LLM to analyze a recall query.
|
||||
|
||||
On LLM failure, returns safe defaults so recall degrades to plain vector search.
|
||||
|
||||
Args:
|
||||
query: The user's recall query.
|
||||
available_scopes: Scope paths that exist in the store.
|
||||
scope_info: Optional info about the current scope.
|
||||
llm: The LLM instance to use.
|
||||
|
||||
Returns:
|
||||
QueryAnalysis with keywords, time_hints, suggested_scopes, complexity.
|
||||
"""
|
||||
scope_desc = ""
|
||||
if scope_info:
|
||||
scope_desc = f"Current scope has {scope_info.record_count} records, categories: {scope_info.categories}"
|
||||
user = _get_prompt("query_user").format(
|
||||
query=query,
|
||||
available_scopes=available_scopes or ["/"],
|
||||
scope_desc=scope_desc,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_prompt("query_system")},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
if getattr(llm, "supports_function_calling", lambda: False)():
|
||||
response = llm.call(messages, response_model=QueryAnalysis)
|
||||
if isinstance(response, QueryAnalysis):
|
||||
return response
|
||||
return QueryAnalysis.model_validate(response)
|
||||
response = llm.call(messages)
|
||||
if isinstance(response, QueryAnalysis):
|
||||
return response
|
||||
if isinstance(response, str):
|
||||
data = json.loads(response)
|
||||
return QueryAnalysis.model_validate(data)
|
||||
return QueryAnalysis.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Query analysis failed, using defaults (complexity=simple): %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
scopes = (available_scopes or ["/"])[:5]
|
||||
return QueryAnalysis(
|
||||
keywords=[],
|
||||
time_hints=[],
|
||||
suggested_scopes=scopes,
|
||||
complexity="simple",
|
||||
)
|
||||
|
||||
|
||||
def analyze_for_consolidation(
|
||||
new_content: str,
|
||||
existing_records: list[MemoryRecord],
|
||||
llm: Any,
|
||||
) -> ConsolidationPlan:
|
||||
"""Use the LLM to decide how to consolidate new content with existing memories.
|
||||
|
||||
On LLM failure, returns a safe default (insert_new=True, no actions) so save
|
||||
is never blocked.
|
||||
|
||||
Args:
|
||||
new_content: The new content to store.
|
||||
existing_records: Existing records that are semantically similar.
|
||||
llm: The LLM instance to use.
|
||||
|
||||
Returns:
|
||||
ConsolidationPlan with actions per record and whether to insert the new content.
|
||||
"""
|
||||
if not existing_records:
|
||||
return ConsolidationPlan(actions=[], insert_new=True)
|
||||
records_lines = []
|
||||
for r in existing_records:
|
||||
created = r.created_at.isoformat() if r.created_at else ""
|
||||
records_lines.append(
|
||||
f"- id={r.id} | scope={r.scope} | importance={r.importance:.2f} | created={created}\n content: {r.content[:200]}{'...' if len(r.content) > 200 else ''}"
|
||||
)
|
||||
user = _get_prompt("consolidation_user").format(
|
||||
new_content=new_content,
|
||||
records_summary="\n\n".join(records_lines),
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": _get_prompt("consolidation_system")},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
if getattr(llm, "supports_function_calling", lambda: False)():
|
||||
response = llm.call(messages, response_model=ConsolidationPlan)
|
||||
if isinstance(response, ConsolidationPlan):
|
||||
return response
|
||||
return ConsolidationPlan.model_validate(response)
|
||||
response = llm.call(messages)
|
||||
if isinstance(response, ConsolidationPlan):
|
||||
return response
|
||||
if isinstance(response, str):
|
||||
data = json.loads(response)
|
||||
return ConsolidationPlan.model_validate(data)
|
||||
return ConsolidationPlan.model_validate(response)
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
"Consolidation analysis failed, defaulting to insert only: %s",
|
||||
e,
|
||||
exc_info=False,
|
||||
)
|
||||
return ConsolidationPlan(actions=[], insert_new=True)
|
||||
|
||||
|
||||
async def aanalyze_query(
|
||||
query: str,
|
||||
available_scopes: list[str],
|
||||
scope_info: ScopeInfo | None,
|
||||
llm: Any,
|
||||
) -> QueryAnalysis:
|
||||
"""Async variant of analyze_query."""
|
||||
return analyze_query(query, available_scopes, scope_info, llm)
|
||||
164
lib/crewai/src/crewai/memory/consolidation_flow.py
Normal file
164
lib/crewai/src/crewai/memory/consolidation_flow.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Consolidation flow: decide insert/update/delete when new content is similar to existing memories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow.flow import Flow, listen, router, start
|
||||
from crewai.memory.analyze import (
|
||||
ConsolidationPlan,
|
||||
analyze_for_consolidation,
|
||||
)
|
||||
from crewai.memory.types import MemoryConfig, MemoryRecord, embed_text
|
||||
|
||||
|
||||
class ConsolidationState(BaseModel):
|
||||
"""State for the consolidation flow."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
new_content: str = ""
|
||||
new_embedding: list[float] = Field(default_factory=list)
|
||||
scope: str = "/"
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
importance: float = 0.5
|
||||
similar_records: list[MemoryRecord] = Field(default_factory=list)
|
||||
top_similarity: float = 0.0
|
||||
plan: ConsolidationPlan | None = None
|
||||
result_record: MemoryRecord | None = None
|
||||
records_updated: int = 0
|
||||
records_deleted: int = 0
|
||||
|
||||
|
||||
class ConsolidationFlow(Flow[ConsolidationState]):
|
||||
"""Flow that gates and runs memory consolidation on remember()."""
|
||||
|
||||
initial_state = ConsolidationState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Any,
|
||||
llm: Any,
|
||||
embedder: Any,
|
||||
config: MemoryConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._storage = storage
|
||||
self._llm = llm
|
||||
self._embedder = embedder
|
||||
self._config = config or MemoryConfig()
|
||||
|
||||
@start()
|
||||
def find_similar(self) -> list[MemoryRecord]:
|
||||
"""Search for existing records similar to the new content (cheap pre-check)."""
|
||||
if not self.state.new_embedding:
|
||||
self.state.top_similarity = 0.0
|
||||
return []
|
||||
scope_prefix = self.state.scope if self.state.scope.strip("/") else None
|
||||
raw = self._storage.search(
|
||||
self.state.new_embedding,
|
||||
scope_prefix=scope_prefix,
|
||||
categories=None,
|
||||
limit=self._config.consolidation_limit,
|
||||
min_score=0.0,
|
||||
)
|
||||
records = [r for r, _ in raw]
|
||||
self.state.similar_records = records
|
||||
if raw:
|
||||
_, top_score = raw[0]
|
||||
self.state.top_similarity = float(top_score)
|
||||
else:
|
||||
self.state.top_similarity = 0.0
|
||||
return records
|
||||
|
||||
@router(find_similar)
|
||||
def check_threshold(self) -> str:
|
||||
"""Gate: only run LLM consolidation when similarity is above threshold."""
|
||||
if self.state.top_similarity < self._config.consolidation_threshold:
|
||||
return "insert_only"
|
||||
return "needs_consolidation"
|
||||
|
||||
@listen("needs_consolidation")
|
||||
def analyze_conflicts(self) -> ConsolidationPlan:
|
||||
"""Run LLM to decide keep/update/delete per record and whether to insert new."""
|
||||
plan = analyze_for_consolidation(
|
||||
self.state.new_content,
|
||||
self.state.similar_records,
|
||||
self._llm,
|
||||
)
|
||||
self.state.plan = plan
|
||||
return plan
|
||||
|
||||
@listen("insert_only")
|
||||
def insert_only_path(self) -> None:
|
||||
"""Fast path: no similar records above threshold; plan stays None."""
|
||||
return None
|
||||
|
||||
@listen(insert_only_path)
|
||||
@listen(analyze_conflicts)
|
||||
def execute_plan(self) -> MemoryRecord | None:
|
||||
"""Apply consolidation plan or insert new record; set state.result_record."""
|
||||
now = datetime.utcnow()
|
||||
plan = self.state.plan
|
||||
record_by_id = {r.id: r for r in self.state.similar_records}
|
||||
|
||||
if plan is not None:
|
||||
for action in plan.actions:
|
||||
if action.action == "delete":
|
||||
self._storage.delete(record_ids=[action.record_id])
|
||||
self.state.records_deleted += 1
|
||||
elif action.action == "update" and action.new_content:
|
||||
existing = record_by_id.get(action.record_id)
|
||||
if existing is not None:
|
||||
new_embedding = embed_text(
|
||||
self._embedder, action.new_content
|
||||
)
|
||||
updated = MemoryRecord(
|
||||
id=existing.id,
|
||||
content=action.new_content,
|
||||
scope=existing.scope,
|
||||
categories=existing.categories,
|
||||
metadata=existing.metadata,
|
||||
importance=existing.importance,
|
||||
created_at=existing.created_at,
|
||||
last_accessed=now,
|
||||
embedding=new_embedding if new_embedding else existing.embedding,
|
||||
)
|
||||
self._storage.update(updated)
|
||||
self.state.records_updated += 1
|
||||
record_by_id[action.record_id] = updated
|
||||
|
||||
if not plan.insert_new:
|
||||
first_updated = next(
|
||||
(
|
||||
record_by_id[a.record_id]
|
||||
for a in plan.actions
|
||||
if a.action == "update" and a.record_id in record_by_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if first_updated is not None:
|
||||
self.state.result_record = first_updated
|
||||
return first_updated
|
||||
if self.state.similar_records:
|
||||
self.state.result_record = self.state.similar_records[0]
|
||||
return self.state.similar_records[0]
|
||||
|
||||
new_record = MemoryRecord(
|
||||
content=self.state.new_content,
|
||||
scope=self.state.scope,
|
||||
categories=self.state.categories,
|
||||
metadata=self.state.metadata,
|
||||
importance=self.state.importance,
|
||||
embedding=self.state.new_embedding if self.state.new_embedding else None,
|
||||
)
|
||||
if not new_record.embedding:
|
||||
new_embedding = embed_text(self._embedder, self.state.new_content)
|
||||
new_record = new_record.model_copy(update={"embedding": new_embedding if new_embedding else None})
|
||||
self._storage.save([new_record])
|
||||
self.state.result_record = new_record
|
||||
return new_record
|
||||
@@ -1,254 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
ExternalMemory,
|
||||
LongTermMemory,
|
||||
ShortTermMemory,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
"""Aggregates and retrieves context from multiple memory sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
exm: ExternalMemory,
|
||||
agent: Agent | None = None,
|
||||
task: Task | None = None,
|
||||
) -> None:
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
self.exm = exm
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
|
||||
if self.stm is not None:
|
||||
self.stm.agent = self.agent
|
||||
self.stm.task = self.task
|
||||
if self.ltm is not None:
|
||||
self.ltm.agent = self.agent
|
||||
self.ltm.task = self.task
|
||||
if self.em is not None:
|
||||
self.em.agent = self.agent
|
||||
self.em.task = self.task
|
||||
if self.exm is not None:
|
||||
self.exm.agent = self.agent
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""Build contextual information for a task synchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
context_parts = [
|
||||
self._fetch_ltm_context(task.description),
|
||||
self._fetch_stm_context(query),
|
||||
self._fetch_entity_context(query),
|
||||
self._fetch_external_context(query),
|
||||
]
|
||||
return "\n".join(filter(None, context_parts))
|
||||
|
||||
async def abuild_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""Build contextual information for a task asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
# Fetch all contexts concurrently
|
||||
results = await asyncio.gather(
|
||||
self._afetch_ltm_context(task.description),
|
||||
self._afetch_stm_context(query),
|
||||
self._afetch_entity_context(query),
|
||||
self._afetch_external_context(query),
|
||||
)
|
||||
|
||||
return "\n".join(filter(None, results))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task: str) -> str | None:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = self.ltm.search(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
if self.em is None:
|
||||
return ""
|
||||
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
def _fetch_external_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant information from External Memory.
|
||||
Args:
|
||||
query (str): The search query to find relevant information.
|
||||
Returns:
|
||||
str: Formatted information as bullet points, or an empty string if none found.
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
external_memories = self.exm.search(query)
|
||||
|
||||
if not external_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
async def _afetch_stm_context(self, query: str) -> str:
|
||||
"""Fetch recent relevant insights from STM asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted insights as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
stm_results = await self.stm.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||
"""Fetch historical data from LTM asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
|
||||
Returns:
|
||||
Formatted historical data as bullet points, or None if none found.
|
||||
"""
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
async def _afetch_entity_context(self, query: str) -> str:
|
||||
"""Fetch relevant entity information asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted entity information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.em is None:
|
||||
return ""
|
||||
|
||||
em_results = await self.em.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
async def _afetch_external_context(self, query: str) -> str:
|
||||
"""Fetch relevant information from External Memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
external_memories = await self.exm.asearch(query)
|
||||
|
||||
if not external_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
@@ -1,404 +0,0 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
"""
|
||||
EntityMemory class for managing structured information about entities
|
||||
and their relationships using SQLite storage.
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves one or more entity items into the SQLite storage.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (included for supertype compatibility but not used).
|
||||
|
||||
Notes:
|
||||
The metadata parameter is included to satisfy the supertype signature but is not
|
||||
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
|
||||
"""
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item and return success status."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super(EntityMemory, self).save(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
success, error = save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search entity memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = super().search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save entity items asynchronously.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||
"""
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors: list[str | None] = []
|
||||
|
||||
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item asynchronously."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
await super(EntityMemory, self).asave(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
success, error = await save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search entity memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the entity memory: {e}"
|
||||
) from e
|
||||
@@ -1,12 +0,0 @@
|
||||
class EntityMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
description: str,
|
||||
relationships: str,
|
||||
):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
self.metadata = {"relationships": relationships}
|
||||
@@ -1,301 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
|
||||
class ExternalMemory(Memory):
|
||||
def __init__(self, storage: Storage | None = None, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@staticmethod
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
return {
|
||||
"mem0": ExternalMemory._configure_mem0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(
|
||||
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
|
||||
) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
|
||||
if "provider" not in embedder_config:
|
||||
raise ValueError("embedder_config must include a 'provider' key")
|
||||
|
||||
provider = embedder_config["provider"]
|
||||
supported_storages = ExternalMemory.external_supported_storages()
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
|
||||
storage: Storage = supported_storages[provider](
|
||||
crew, embedder_config.get("config", {})
|
||||
)
|
||||
return storage
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
super().save(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search external memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = super().search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to external memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
await super().asave(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search external memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
def set_crew(self, crew: Any) -> ExternalMemory:
|
||||
super().set_crew(crew)
|
||||
|
||||
if not self.storage:
|
||||
self.storage = self.create_storage(crew, self.embedder_config) # type: ignore[arg-type]
|
||||
|
||||
return self
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExternalMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
):
|
||||
self.value = value
|
||||
self.metadata = metadata
|
||||
self.agent = agent
|
||||
@@ -1,255 +0,0 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
"""
|
||||
LongTermMemory class for managing cross runs data related to overall crew's
|
||||
execution and performance.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: LTMSQLiteStorage | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||
"""Save an item to long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
item: The LongTermMemoryItem to save.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
await self.storage.asave(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.aload(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset long-term memory."""
|
||||
self.storage.reset()
|
||||
@@ -1,19 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LongTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
agent: str,
|
||||
task: str,
|
||||
expected_output: str,
|
||||
datetime: str,
|
||||
quality: int | float | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.quality = quality
|
||||
self.datetime = datetime
|
||||
self.expected_output = expected_output
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
@@ -1,121 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
_agent: Agent | None = None
|
||||
_task: Task | None = None
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@property
|
||||
def task(self) -> Task | None:
|
||||
"""Get the current task associated with this memory."""
|
||||
return self._task
|
||||
|
||||
@task.setter
|
||||
def task(self, task: Task | None) -> None:
|
||||
"""Set the current task associated with this memory."""
|
||||
self._task = task
|
||||
|
||||
@property
|
||||
def agent(self) -> Agent | None:
|
||||
"""Get the current agent associated with this memory."""
|
||||
return self._agent
|
||||
|
||||
@agent.setter
|
||||
def agent(self, agent: Agent | None) -> None:
|
||||
"""Set the current agent associated with this memory."""
|
||||
self._agent = agent
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to memory.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
await self.storage.asave(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search memory for relevant entries asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
def set_crew(self, crew: Any) -> Memory:
|
||||
"""Set the crew for this memory instance."""
|
||||
self.crew = crew
|
||||
return self
|
||||
256
lib/crewai/src/crewai/memory/memory_scope.py
Normal file
256
lib/crewai/src/crewai/memory/memory_scope.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Scoped and sliced views over unified Memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
MemoryMatch,
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
)
|
||||
|
||||
|
||||
class MemoryScope:
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
def __init__(self, memory: Memory, root_path: str) -> None:
|
||||
"""Initialize scope.
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
root_path: Root path for this scope (e.g. /agent/1).
|
||||
"""
|
||||
self._memory = memory
|
||||
self._root = root_path.rstrip("/") or ""
|
||||
if self._root and not self._root.startswith("/"):
|
||||
self._root = "/" + self._root
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
return self._root or "/"
|
||||
s = scope.rstrip("/")
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
if not self._root:
|
||||
return s
|
||||
base = self._root.rstrip("/")
|
||||
return f"{base}{s}"
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
scope: str | None = "/",
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
) -> MemoryRecord:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
content,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
metadata=metadata,
|
||||
importance=importance,
|
||||
)
|
||||
|
||||
def recall(
|
||||
self,
|
||||
query: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "auto",
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall within this scope (root path and below)."""
|
||||
search_scope = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.recall(
|
||||
query,
|
||||
scope=search_scope,
|
||||
categories=categories,
|
||||
limit=limit,
|
||||
depth=depth,
|
||||
)
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
|
||||
def forget(
|
||||
self,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""Forget within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.forget(
|
||||
scope=prefix,
|
||||
categories=categories,
|
||||
older_than=older_than,
|
||||
metadata_filter=metadata_filter,
|
||||
record_ids=record_ids,
|
||||
)
|
||||
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List child scopes under path (relative to this scope's root)."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.list_scopes(full)
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Info for path under this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.info(full)
|
||||
|
||||
def tree(self, path: str = "/", max_depth: int = 3) -> str:
|
||||
"""Tree under path within this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.tree(full, max_depth=max_depth)
|
||||
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories in this scope; path None means this scope root."""
|
||||
full = self._scope_path(path) if path else (self._root or "/")
|
||||
return self._memory.list_categories(full)
|
||||
|
||||
def reset(self, scope: str | None = None) -> None:
|
||||
"""Reset within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
self._memory.reset(scope=prefix)
|
||||
|
||||
def subscope(self, path: str) -> MemoryScope:
|
||||
"""Return a narrower scope under this scope."""
|
||||
child = path.strip("/")
|
||||
if not child:
|
||||
return MemoryScope(self._memory, self._root or "/")
|
||||
base = self._root.rstrip("/") or ""
|
||||
new_root = f"{base}/{child}" if base else f"/{child}"
|
||||
return MemoryScope(self._memory, new_root)
|
||||
|
||||
|
||||
class MemorySlice:
|
||||
"""View over multiple scopes: recall searches all, remember requires explicit scope unless read_only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory: Memory,
|
||||
scopes: list[str],
|
||||
categories: list[str] | None = None,
|
||||
read_only: bool = True,
|
||||
) -> None:
|
||||
"""Initialize slice.
|
||||
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
scopes: List of scope paths to include.
|
||||
categories: Optional category filter for recall.
|
||||
read_only: If True, remember() raises PermissionError.
|
||||
"""
|
||||
self._memory = memory
|
||||
self._scopes = [s.rstrip("/") or "/" for s in scopes]
|
||||
self._categories = categories
|
||||
self._read_only = read_only
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
scope: str,
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
) -> MemoryRecord:
|
||||
"""Remember into an explicit scope. Required when read_only=False."""
|
||||
if self._read_only:
|
||||
raise PermissionError("This MemorySlice is read-only")
|
||||
return self._memory.remember(
|
||||
content,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
metadata=metadata,
|
||||
importance=importance,
|
||||
)
|
||||
|
||||
def recall(
|
||||
self,
|
||||
query: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: str = "auto",
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall across all slice scopes; results merged and re-ranked."""
|
||||
cats = categories or self._categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self._scopes:
|
||||
matches = self._memory.recall(
|
||||
query,
|
||||
scope=sc,
|
||||
categories=cats,
|
||||
limit=limit * _RECALL_OVERSAMPLE_FACTOR,
|
||||
depth=depth,
|
||||
)
|
||||
all_matches.extend(matches)
|
||||
seen_ids: set[str] = set()
|
||||
unique: list[MemoryMatch] = []
|
||||
for m in sorted(all_matches, key=lambda x: x.score, reverse=True):
|
||||
if m.record.id not in seen_ids:
|
||||
seen_ids.add(m.record.id)
|
||||
unique.append(m)
|
||||
if len(unique) >= limit:
|
||||
break
|
||||
return unique
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self._scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
return sorted(set(out))
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Aggregate info across slice scopes (record counts summed)."""
|
||||
total_records = 0
|
||||
all_categories: set[str] = set()
|
||||
oldest: datetime | None = None
|
||||
newest: datetime | None = None
|
||||
children: list[str] = []
|
||||
for sc in self._scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
|
||||
if inf.newest_record:
|
||||
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
|
||||
children.extend(inf.child_scopes)
|
||||
return ScopeInfo(
|
||||
path=path,
|
||||
record_count=total_records,
|
||||
categories=sorted(all_categories),
|
||||
oldest_record=oldest,
|
||||
newest_record=newest,
|
||||
child_scopes=sorted(set(children)),
|
||||
)
|
||||
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories and counts across slice scopes."""
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self._scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
return counts
|
||||
190
lib/crewai/src/crewai/memory/recall_flow.py
Normal file
190
lib/crewai/src/crewai/memory/recall_flow.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""RLM-inspired intelligent recall flow for memory retrieval."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow.flow import Flow, listen, router, start
|
||||
from crewai.memory.analyze import QueryAnalysis, analyze_query
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
MemoryConfig,
|
||||
MemoryMatch,
|
||||
MemoryRecord,
|
||||
compute_composite_score,
|
||||
)
|
||||
|
||||
|
||||
class RecallState(BaseModel):
|
||||
"""State for the recall flow."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
query: str = ""
|
||||
scope: str | None = None
|
||||
categories: list[str] | None = None
|
||||
limit: int = 10
|
||||
query_embedding: list[float] = Field(default_factory=list)
|
||||
query_analysis: QueryAnalysis | None = None
|
||||
candidate_scopes: list[str] = Field(default_factory=list)
|
||||
chunk_findings: list[Any] = Field(default_factory=list)
|
||||
evidence_gaps: list[str] = Field(default_factory=list)
|
||||
confidence: float = 0.0
|
||||
final_results: list[MemoryMatch] = Field(default_factory=list)
|
||||
exploration_budget: int = 1
|
||||
|
||||
|
||||
class RecallFlow(Flow[RecallState]):
|
||||
"""RLM-inspired intelligent memory recall flow."""
|
||||
|
||||
initial_state = RecallState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Any,
|
||||
llm: Any,
|
||||
config: MemoryConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._storage = storage
|
||||
self._llm = llm
|
||||
self._config = config or MemoryConfig()
|
||||
|
||||
@start()
|
||||
def analyze_query_step(self) -> QueryAnalysis:
|
||||
self.state.exploration_budget = self._config.exploration_budget
|
||||
available = self._storage.list_scopes(self.state.scope or "/")
|
||||
if not available:
|
||||
available = ["/"]
|
||||
scope_info = self._storage.get_scope_info(self.state.scope or "/") if self.state.scope else None
|
||||
analysis = analyze_query(
|
||||
self.state.query,
|
||||
available,
|
||||
scope_info,
|
||||
self._llm,
|
||||
)
|
||||
self.state.query_analysis = analysis
|
||||
return analysis
|
||||
|
||||
@listen(analyze_query_step)
|
||||
def filter_and_chunk(self) -> list[str]:
|
||||
analysis = self.state.query_analysis
|
||||
scope_prefix = (self.state.scope or "/").rstrip("/") or "/"
|
||||
if analysis and analysis.suggested_scopes:
|
||||
candidates = [s for s in analysis.suggested_scopes if s]
|
||||
else:
|
||||
candidates = self._storage.list_scopes(scope_prefix)
|
||||
if not candidates:
|
||||
info = self._storage.get_scope_info(scope_prefix)
|
||||
if info.record_count > 0:
|
||||
candidates = [scope_prefix]
|
||||
else:
|
||||
candidates = [scope_prefix]
|
||||
self.state.candidate_scopes = candidates[:20]
|
||||
return self.state.candidate_scopes
|
||||
|
||||
@listen(filter_and_chunk)
|
||||
def search_chunks(self) -> list[Any]:
|
||||
findings = []
|
||||
for scope in self.state.candidate_scopes:
|
||||
results = self._storage.search(
|
||||
self.state.query_embedding,
|
||||
scope_prefix=scope,
|
||||
categories=self.state.categories,
|
||||
limit=self.state.limit * _RECALL_OVERSAMPLE_FACTOR,
|
||||
min_score=0.0,
|
||||
)
|
||||
if results:
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
)
|
||||
findings.append(
|
||||
{
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
}
|
||||
)
|
||||
self.state.chunk_findings = findings
|
||||
if findings:
|
||||
self.state.confidence = max(f["top_score"] for f in findings)
|
||||
else:
|
||||
self.state.confidence = 0.0
|
||||
return findings
|
||||
|
||||
@router(search_chunks)
|
||||
def decide_depth(self) -> str:
|
||||
analysis = self.state.query_analysis
|
||||
if (
|
||||
analysis
|
||||
and analysis.complexity == "complex"
|
||||
and self.state.confidence < self._config.complex_query_threshold
|
||||
):
|
||||
if self.state.exploration_budget > 0:
|
||||
return "explore_deeper"
|
||||
if self.state.confidence >= self._config.confidence_threshold_high:
|
||||
return "synthesize"
|
||||
if (
|
||||
self.state.exploration_budget > 0
|
||||
and self.state.confidence < self._config.confidence_threshold_low
|
||||
):
|
||||
return "explore_deeper"
|
||||
return "synthesize"
|
||||
|
||||
@listen("explore_deeper")
|
||||
def recursive_exploration(self) -> list[Any]:
|
||||
enhanced = []
|
||||
for finding in self.state.chunk_findings:
|
||||
if not finding.get("results"):
|
||||
continue
|
||||
content_parts = [r[0].content for r in finding["results"][:5]]
|
||||
chunk_text = "\n---\n".join(content_parts)
|
||||
prompt = (
|
||||
f"Query: {self.state.query}\n\n"
|
||||
f"Relevant memory excerpts:\n{chunk_text}\n\n"
|
||||
"Extract the most relevant information for the query. "
|
||||
"If something is missing, say what's missing in one short line."
|
||||
)
|
||||
try:
|
||||
response = self._llm.call([{"role": "user", "content": prompt}])
|
||||
if isinstance(response, str) and "missing" in response.lower():
|
||||
self.state.evidence_gaps.append(response[:200])
|
||||
enhanced.append({"scope": finding["scope"], "extraction": response, "results": finding["results"]})
|
||||
except Exception:
|
||||
enhanced.append({"scope": finding["scope"], "extraction": "", "results": finding["results"]})
|
||||
self.state.chunk_findings = enhanced
|
||||
return enhanced
|
||||
|
||||
@listen("synthesize")
|
||||
@listen(recursive_exploration)
|
||||
def synthesize_results(self) -> list[MemoryMatch]:
|
||||
seen_ids: set[str] = set()
|
||||
matches: list[MemoryMatch] = []
|
||||
for finding in self.state.chunk_findings:
|
||||
if not isinstance(finding, dict):
|
||||
continue
|
||||
results = finding.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
continue
|
||||
for item in results:
|
||||
if isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
record, score = item[0], item[1]
|
||||
else:
|
||||
continue
|
||||
if isinstance(record, MemoryRecord) and record.id not in seen_ids:
|
||||
seen_ids.add(record.id)
|
||||
composite, reasons = compute_composite_score(
|
||||
record, float(score), self._config
|
||||
)
|
||||
matches.append(
|
||||
MemoryMatch(
|
||||
record=record,
|
||||
score=composite,
|
||||
match_reasons=reasons,
|
||||
)
|
||||
)
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
self.state.final_results = matches[: self.state.limit]
|
||||
return self.state.final_results
|
||||
@@ -1,318 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
"""
|
||||
ShortTermMemory class for managing transient data related to immediate tasks
|
||||
and interactions.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="short_term",
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
path=path,
|
||||
)
|
||||
)
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
# agent_role=agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
await super().asave(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the short-term memory: {e}"
|
||||
) from e
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ShortTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
data: Any,
|
||||
agent: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
150
lib/crewai/src/crewai/memory/storage/backend.py
Normal file
150
lib/crewai/src/crewai/memory/storage/backend.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Storage backend protocol for the unified memory system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StorageBackend(Protocol):
|
||||
"""Protocol for pluggable memory storage backends."""
|
||||
|
||||
def save(self, records: list[MemoryRecord]) -> None:
|
||||
"""Save memory records to storage.
|
||||
|
||||
Args:
|
||||
records: List of memory records to persist.
|
||||
"""
|
||||
...
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.0,
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
"""Search for memories by vector similarity with optional filters.
|
||||
|
||||
Args:
|
||||
query_embedding: Embedding vector for the query.
|
||||
scope_prefix: Optional scope path prefix to filter results.
|
||||
categories: Optional list of categories to filter by.
|
||||
metadata_filter: Optional metadata key-value filter.
|
||||
limit: Maximum number of results to return.
|
||||
min_score: Minimum similarity score threshold.
|
||||
|
||||
Returns:
|
||||
List of (MemoryRecord, score) tuples ordered by relevance.
|
||||
"""
|
||||
...
|
||||
|
||||
def delete(
|
||||
self,
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
"""Delete memories matching the given criteria.
|
||||
|
||||
Args:
|
||||
scope_prefix: Optional scope path prefix.
|
||||
categories: Optional list of categories.
|
||||
record_ids: Optional list of record IDs to delete.
|
||||
older_than: Optional cutoff datetime (delete older records).
|
||||
metadata_filter: Optional metadata key-value filter.
|
||||
|
||||
Returns:
|
||||
Number of records deleted.
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update an existing record. Replaces the record with the same ID."""
|
||||
...
|
||||
|
||||
def get_scope_info(self, scope: str) -> ScopeInfo:
|
||||
"""Get information about a scope.
|
||||
|
||||
Args:
|
||||
scope: The scope path.
|
||||
|
||||
Returns:
|
||||
ScopeInfo with record count, categories, date range, child scopes.
|
||||
"""
|
||||
...
|
||||
|
||||
def list_scopes(self, parent: str = "/") -> list[str]:
|
||||
"""List immediate child scopes under a parent path.
|
||||
|
||||
Args:
|
||||
parent: Parent scope path (default root).
|
||||
|
||||
Returns:
|
||||
List of immediate child scope paths.
|
||||
"""
|
||||
...
|
||||
|
||||
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
|
||||
"""List categories and their counts within a scope.
|
||||
|
||||
Args:
|
||||
scope_prefix: Optional scope to limit to (None = global).
|
||||
|
||||
Returns:
|
||||
Mapping of category name to record count.
|
||||
"""
|
||||
...
|
||||
|
||||
def count(self, scope_prefix: str | None = None) -> int:
|
||||
"""Count records in scope (and subscopes).
|
||||
|
||||
Args:
|
||||
scope_prefix: Optional scope path (None = all).
|
||||
|
||||
Returns:
|
||||
Number of records.
|
||||
"""
|
||||
...
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
"""Reset (delete all) memories in scope.
|
||||
|
||||
Args:
|
||||
scope_prefix: Optional scope path (None = reset all).
|
||||
"""
|
||||
...
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
"""Save memory records asynchronously."""
|
||||
...
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.0,
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
"""Search for memories asynchronously."""
|
||||
...
|
||||
|
||||
async def adelete(
|
||||
self,
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
"""Delete memories asynchronously."""
|
||||
...
|
||||
@@ -1,16 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Abstract base class defining the storage interface"""
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int, score_threshold: float
|
||||
) -> dict[str, Any] | list[Any]:
|
||||
return {}
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
357
lib/crewai/src/crewai/memory/storage/lancedb_storage.py
Normal file
357
lib/crewai/src/crewai/memory/storage/lancedb_storage.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""LanceDB storage backend for the unified memory system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
|
||||
|
||||
# Default embedding vector dimensionality (matches OpenAI text-embedding-3-small).
|
||||
# Used when creating new tables and for zero-vector placeholder scans.
|
||||
# Callers can override via the ``vector_dim`` constructor parameter.
|
||||
DEFAULT_VECTOR_DIM = 1536
|
||||
|
||||
# Safety cap on the number of rows returned by a single scan query.
|
||||
# Prevents unbounded memory use when scanning large tables for scope info,
|
||||
# listing, or deletion. Internal only -- not user-configurable.
|
||||
_SCAN_ROWS_LIMIT = 50_000
|
||||
|
||||
|
||||
class LanceDBStorage:
|
||||
"""LanceDB-backed storage for the unified memory system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
table_name: str = "memories",
|
||||
vector_dim: int = DEFAULT_VECTOR_DIM,
|
||||
) -> None:
|
||||
"""Initialize LanceDB storage.
|
||||
|
||||
Args:
|
||||
path: Directory path for the LanceDB database. Defaults to
|
||||
``$CREWAI_STORAGE_DIR/memory`` if the env var is set,
|
||||
otherwise ``./.crewai/memory``.
|
||||
table_name: Name of the table for memory records.
|
||||
vector_dim: Dimensionality of the embedding vector.
|
||||
"""
|
||||
if path is None:
|
||||
storage_dir = os.environ.get("CREWAI_STORAGE_DIR")
|
||||
path = Path(storage_dir) / "memory" if storage_dir else Path("./.crewai/memory")
|
||||
self._path = Path(path)
|
||||
self._path.mkdir(parents=True, exist_ok=True)
|
||||
self._table_name = table_name
|
||||
self._vector_dim = vector_dim
|
||||
self._db = lancedb.connect(str(self._path))
|
||||
self._table = self._get_or_create_table()
|
||||
|
||||
def _get_or_create_table(self) -> lancedb.table.Table:
|
||||
try:
|
||||
return self._db.open_table(self._table_name)
|
||||
except Exception:
|
||||
# Create empty table with schema from first record or placeholder
|
||||
placeholder = [
|
||||
{
|
||||
"id": "__schema_placeholder__",
|
||||
"content": "",
|
||||
"scope": "/",
|
||||
"categories_str": "[]",
|
||||
"metadata_str": "{}",
|
||||
"importance": 0.5,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"last_accessed": datetime.utcnow().isoformat(),
|
||||
"vector": [0.0] * self._vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
# Remove placeholder row
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
return table
|
||||
|
||||
def _record_to_row(self, record: MemoryRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": record.id,
|
||||
"content": record.content,
|
||||
"scope": record.scope,
|
||||
"categories_str": json.dumps(record.categories),
|
||||
"metadata_str": json.dumps(record.metadata),
|
||||
"importance": record.importance,
|
||||
"created_at": record.created_at.isoformat(),
|
||||
"last_accessed": record.last_accessed.isoformat(),
|
||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
||||
}
|
||||
|
||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||
def _parse_dt(val: Any) -> datetime:
|
||||
if val is None:
|
||||
return datetime.utcnow()
|
||||
if isinstance(val, datetime):
|
||||
return val
|
||||
s = str(val)
|
||||
return datetime.fromisoformat(s.replace("Z", "+00:00"))
|
||||
|
||||
return MemoryRecord(
|
||||
id=str(row["id"]),
|
||||
content=str(row["content"]),
|
||||
scope=str(row["scope"]),
|
||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||
importance=float(row.get("importance", 0.5)),
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
last_accessed=_parse_dt(row.get("last_accessed")),
|
||||
embedding=row.get("vector"),
|
||||
)
|
||||
|
||||
def save(self, records: list[MemoryRecord]) -> None:
|
||||
if not records:
|
||||
return
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
self._table.add(rows)
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._table.delete(f"id = '{safe_id}'")
|
||||
row = self._record_to_row(record)
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._table.add([row])
|
||||
|
||||
def get_record(self, record_id: str) -> MemoryRecord | None:
|
||||
"""Return a record by ID, or None if not found. LanceDB-only."""
|
||||
safe_id = str(record_id).replace("'", "''")
|
||||
rows = self._table.search([0.0] * self._vector_dim).where(f"id = '{safe_id}'").limit(1).to_list()
|
||||
if not rows:
|
||||
return None
|
||||
return self._row_to_record(rows[0])
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.0,
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
query = self._table.search(query_embedding)
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
continue
|
||||
distance = row.get("_distance", 0.0)
|
||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||
if score >= min_score:
|
||||
out.append((record, score))
|
||||
if len(out) >= limit:
|
||||
break
|
||||
return out[:limit]
|
||||
|
||||
def delete(
|
||||
self,
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._table.delete(f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
continue
|
||||
if older_than and record.created_at >= older_than:
|
||||
continue
|
||||
to_delete.append(record.id)
|
||||
if not to_delete:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._table.delete(f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if not prefix.startswith("/"):
|
||||
prefix = "/" + prefix
|
||||
conditions.append(f"scope LIKE '{prefix}%' OR scope = '/'")
|
||||
if older_than is not None:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
# Delete all rows (scope_prefix is "/" or None and no older_than)
|
||||
before = self._table.count_rows()
|
||||
self._table.delete("id != ''")
|
||||
return before - self._table.count_rows()
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
self._table.delete(where_expr)
|
||||
return before - self._table.count_rows()
|
||||
|
||||
def _scan_rows(self, scope_prefix: str | None = None, limit: int = _SCAN_ROWS_LIMIT) -> list[dict[str, Any]]:
|
||||
"""Scan rows optionally filtered by scope prefix."""
|
||||
q = self._table.search([0.0] * self._vector_dim)
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
|
||||
return q.limit(limit).to_list()
|
||||
|
||||
def list_records(
|
||||
self, scope_prefix: str | None = None, limit: int = 200
|
||||
) -> list[MemoryRecord]:
|
||||
"""List records in a scope, newest first. LanceDB-only (TUI use)."""
|
||||
rows = self._scan_rows(scope_prefix, limit=limit)
|
||||
records = [self._row_to_record(r) for r in rows]
|
||||
records.sort(key=lambda r: r.created_at, reverse=True)
|
||||
return records[:limit]
|
||||
|
||||
def get_scope_info(self, scope: str) -> ScopeInfo:
|
||||
scope = scope.rstrip("/") or "/"
|
||||
prefix = scope if scope != "/" else ""
|
||||
if prefix and not prefix.startswith("/"):
|
||||
prefix = "/" + prefix
|
||||
rows = self._scan_rows(prefix or None)
|
||||
if not rows:
|
||||
return ScopeInfo(
|
||||
path=scope or "/",
|
||||
record_count=0,
|
||||
categories=[],
|
||||
oldest_record=None,
|
||||
newest_record=None,
|
||||
child_scopes=[],
|
||||
)
|
||||
categories_set: set[str] = set()
|
||||
oldest: datetime | None = None
|
||||
newest: datetime | None = None
|
||||
child_prefix = (prefix + "/") if prefix else "/"
|
||||
children: set[str] = set()
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if child_prefix and sc.startswith(child_prefix):
|
||||
rest = sc[len(child_prefix):]
|
||||
if "/" not in rest and rest:
|
||||
children.add(sc)
|
||||
try:
|
||||
cat_str = row.get("categories_str") or "[]"
|
||||
categories_set.update(json.loads(cat_str))
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
created = row.get("created_at")
|
||||
if created:
|
||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
||||
if isinstance(dt, datetime):
|
||||
if oldest is None or dt < oldest:
|
||||
oldest = dt
|
||||
if newest is None or dt > newest:
|
||||
newest = dt
|
||||
return ScopeInfo(
|
||||
path=scope or "/",
|
||||
record_count=len(rows),
|
||||
categories=sorted(categories_set),
|
||||
oldest_record=oldest,
|
||||
newest_record=newest,
|
||||
child_scopes=sorted(children),
|
||||
)
|
||||
|
||||
def list_scopes(self, parent: str = "/") -> list[str]:
|
||||
parent = parent.rstrip("/") or ""
|
||||
prefix = (parent + "/") if parent else "/"
|
||||
rows = self._scan_rows(prefix if prefix != "/" else None)
|
||||
children: set[str] = set()
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||
rest = sc[len(prefix):]
|
||||
if "/" not in rest and rest:
|
||||
children.add(prefix + rest)
|
||||
return sorted(children)
|
||||
|
||||
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
cat_str = row.get("categories_str") or "[]"
|
||||
try:
|
||||
parsed = json.loads(cat_str)
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
for c in parsed:
|
||||
counts[c] = counts.get(c, 0) + 1
|
||||
return counts
|
||||
|
||||
def count(self, scope_prefix: str | None = None) -> int:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
return self._table.count_rows()
|
||||
info = self.get_scope_info(scope_prefix)
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = self._get_or_create_table()
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
self.save(records)
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.0,
|
||||
) -> list[tuple[MemoryRecord, float]]:
|
||||
return self.search(
|
||||
query_embedding,
|
||||
scope_prefix=scope_prefix,
|
||||
categories=categories,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
)
|
||||
|
||||
async def adelete(
|
||||
self,
|
||||
scope_prefix: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
return self.delete(
|
||||
scope_prefix=scope_prefix,
|
||||
categories=categories,
|
||||
record_ids=record_ids,
|
||||
older_than=older_than,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
@@ -1,215 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""SQLite storage class for long-term memory data."""
|
||||
|
||||
def __init__(self, db_path: str | None = None, verbose: bool = True) -> None:
|
||||
"""Initialize the SQLite storage.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file.
|
||||
verbose: Whether to print error messages.
|
||||
"""
|
||||
if db_path is None:
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
self.db_path = db_path
|
||||
self._verbose = verbose
|
||||
self._printer: Printer = Printer()
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initialize the SQLite database and create LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS long_term_memories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_description TEXT,
|
||||
metadata TEXT,
|
||||
datetime TEXT,
|
||||
score REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred during database initialization: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
except sqlite3.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM long_term_memories")
|
||||
conn.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Save data to the LTM table asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task.
|
||||
metadata: Metadata associated with the memory.
|
||||
datetime: Timestamp of the memory.
|
||||
score: Quality score of the memory.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
async def aload(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Query the LTM table by task description asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries or None if error occurs.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
cursor = await conn.execute(
|
||||
f"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except aiosqlite.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the LTM table asynchronously."""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute("DELETE FROM long_term_memories")
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
if self._verbose:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
@@ -1,230 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found]
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.rag.chromadb.utils import _sanitize_collection_name
|
||||
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
|
||||
|
||||
class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
super().__init__()
|
||||
|
||||
self._validate_type(type)
|
||||
self.memory_type = type
|
||||
self.crew = crew
|
||||
self.config = config or {}
|
||||
|
||||
self._extract_config_values()
|
||||
self._initialize_memory()
|
||||
|
||||
def _validate_type(self, type):
|
||||
supported_types = {"short_term", "long_term", "entities", "external"}
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. "
|
||||
f"Must be one of: {', '.join(supported_types)}"
|
||||
)
|
||||
|
||||
def _extract_config_values(self):
|
||||
self.mem0_run_id = self.config.get("run_id")
|
||||
self.includes = self.config.get("includes")
|
||||
self.excludes = self.config.get("excludes")
|
||||
self.custom_categories = self.config.get("custom_categories")
|
||||
self.infer = self.config.get("infer", True)
|
||||
|
||||
def _initialize_memory(self):
|
||||
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
org_id = self.config.get("org_id")
|
||||
project_id = self.config.get("project_id")
|
||||
local_config = self.config.get("local_mem0_config")
|
||||
|
||||
if api_key:
|
||||
self.memory = (
|
||||
MemoryClient(api_key=api_key, org_id=org_id, project_id=project_id)
|
||||
if org_id and project_id
|
||||
else MemoryClient(api_key=api_key)
|
||||
)
|
||||
if self.custom_categories:
|
||||
self.memory.update_project(custom_categories=self.custom_categories)
|
||||
else:
|
||||
self.memory = (
|
||||
Memory.from_config(local_config)
|
||||
if local_config and len(local_config)
|
||||
else Memory()
|
||||
)
|
||||
|
||||
def _create_filter_for_search(self):
|
||||
"""
|
||||
Returns:
|
||||
dict: A filter dictionary containing AND conditions for querying data.
|
||||
- Includes user_id and agent_id if both are present.
|
||||
- Includes user_id if only user_id is present.
|
||||
- Includes agent_id if only agent_id is present.
|
||||
- Includes run_id if memory_type is 'short_term' and
|
||||
mem0_run_id is present.
|
||||
"""
|
||||
filter = defaultdict(list)
|
||||
|
||||
if self.memory_type == "short_term" and self.mem0_run_id:
|
||||
filter["AND"].append({"run_id": self.mem0_run_id})
|
||||
else:
|
||||
user_id = self.config.get("user_id", "")
|
||||
agent_id = self.config.get("agent_id", "")
|
||||
|
||||
if user_id and agent_id:
|
||||
filter["OR"].append({"user_id": user_id})
|
||||
filter["OR"].append({"agent_id": agent_id})
|
||||
elif user_id:
|
||||
filter["AND"].append({"user_id": user_id})
|
||||
elif agent_id:
|
||||
filter["AND"].append({"agent_id": agent_id})
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str:
|
||||
return next(
|
||||
(
|
||||
m.get("content", "")
|
||||
for m in reversed(list(messages))
|
||||
if m.get("role") == role
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
conversations = []
|
||||
messages = metadata.pop("messages", None)
|
||||
if messages:
|
||||
last_user = _last_content(messages, "user")
|
||||
last_assistant = _last_content(messages, "assistant")
|
||||
|
||||
if user_msg := self._get_user_message(last_user):
|
||||
conversations.append({"role": "user", "content": user_msg})
|
||||
|
||||
if assistant_msg := self._get_assistant_message(last_assistant):
|
||||
conversations.append({"role": "assistant", "content": assistant_msg})
|
||||
else:
|
||||
conversations.append({"role": "assistant", "content": value})
|
||||
|
||||
user_id = self.config.get("user_id", "")
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
"long_term": "long_term",
|
||||
"entities": "entity",
|
||||
"external": "external",
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"infer": self.infer,
|
||||
}
|
||||
|
||||
# MemoryClient-specific overrides
|
||||
if isinstance(self.memory, MemoryClient):
|
||||
params["includes"] = self.includes
|
||||
params["excludes"] = self.excludes
|
||||
params["output_format"] = "v1.1"
|
||||
params["version"] = "v2"
|
||||
|
||||
if self.memory_type == "short_term" and self.mem0_run_id:
|
||||
params["run_id"] = self.mem0_run_id
|
||||
|
||||
if user_id:
|
||||
params["user_id"] = user_id
|
||||
|
||||
if agent_id := self.config.get("agent_id", self._get_agent_name()):
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
self.memory.add(conversations, **params)
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1",
|
||||
}
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
|
||||
memory_type_map = {
|
||||
"short_term": {"type": "short_term"},
|
||||
"long_term": {"type": "long_term"},
|
||||
"entities": {"type": "entity"},
|
||||
"external": {"type": "external"},
|
||||
}
|
||||
|
||||
if self.memory_type in memory_type_map:
|
||||
params["metadata"] = memory_type_map[self.memory_type]
|
||||
if self.memory_type == "short_term":
|
||||
params["run_id"] = self.mem0_run_id
|
||||
|
||||
# Discard the filters for now since we create the filters
|
||||
# automatically when the crew is created.
|
||||
|
||||
params["filters"] = self._create_filter_for_search()
|
||||
params["threshold"] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params["output_format"]
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
results = self.memory.search(**params)
|
||||
|
||||
# This makes it compatible for Contextual Memory to retrieve
|
||||
for result in results["results"]:
|
||||
result["content"] = result["memory"]
|
||||
|
||||
return [r for r in results["results"]]
|
||||
|
||||
def reset(self):
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _get_agent_name(self) -> str:
|
||||
if not self.crew:
|
||||
return ""
|
||||
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return _sanitize_collection_name(
|
||||
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
|
||||
)
|
||||
|
||||
def _get_assistant_message(self, text: str) -> str:
|
||||
marker = "Final Answer:"
|
||||
if marker in text:
|
||||
return text.split(marker, 1)[1].strip()
|
||||
return text
|
||||
|
||||
def _get_user_message(self, text: str) -> str:
|
||||
pattern = r"User message:\s*(.*)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return text
|
||||
@@ -1,315 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
import warnings
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Crew | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
crew_agents = crew.agents if crew else []
|
||||
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
|
||||
agents_str = "_".join(sanitized_roles)
|
||||
self.agents = agents_str
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents_str)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = build_embedder(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config["provider"]
|
||||
if isinstance(self.embedder_config, dict)
|
||||
else self.embedder_config.__class__.__name__.replace(
|
||||
"Provider", ""
|
||||
).lower()
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
|
||||
if self.path:
|
||||
config.settings.persist_directory = self.path
|
||||
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
return self._client if self._client else get_rag_client()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
@staticmethod
|
||||
def _build_storage_file_name(type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
client.get_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return client.search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
if "attempt to write a readonly database" in str(
|
||||
e
|
||||
) or "does not exist" in str(e):
|
||||
# Ignore readonly database and collection not found errors (already reset)
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
) from e
|
||||
285
lib/crewai/src/crewai/memory/types.py
Normal file
285
lib/crewai/src/crewai/memory/types.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""Data types for the unified memory system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# When searching the vector store, we ask for more results than the caller
|
||||
# requested so that post-search steps (composite scoring, deduplication,
|
||||
# category filtering) have enough candidates to fill the final result set.
|
||||
# For example, if the caller asks for 10 results and this is 2, we fetch 20
|
||||
# from the vector store and then trim down after scoring.
|
||||
_RECALL_OVERSAMPLE_FACTOR = 2
|
||||
|
||||
|
||||
class MemoryRecord(BaseModel):
|
||||
"""A single memory entry stored in the memory system."""
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
description="Unique identifier for the memory record.",
|
||||
)
|
||||
content: str = Field(description="The textual content of the memory.")
|
||||
scope: str = Field(
|
||||
default="/",
|
||||
description="Hierarchical path organizing the memory (e.g. /company/team/user).",
|
||||
)
|
||||
categories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Categories or tags for the memory.",
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Arbitrary metadata associated with the memory.",
|
||||
)
|
||||
importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score from 0.0 to 1.0, affects retrieval ranking.",
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow,
|
||||
description="When the memory was created.",
|
||||
)
|
||||
last_accessed: datetime = Field(
|
||||
default_factory=datetime.utcnow,
|
||||
description="When the memory was last accessed.",
|
||||
)
|
||||
embedding: list[float] | None = Field(
|
||||
default=None,
|
||||
description="Vector embedding for semantic search. Computed on save if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryMatch(BaseModel):
|
||||
"""A memory record with relevance score from a recall operation."""
|
||||
|
||||
record: MemoryRecord = Field(description="The matched memory record.")
|
||||
score: float = Field(
|
||||
description="Combined relevance score (semantic, recency, importance).",
|
||||
)
|
||||
match_reasons: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Reasons for the match (e.g. semantic, recency, importance).",
|
||||
)
|
||||
|
||||
|
||||
class ScopeInfo(BaseModel):
|
||||
"""Information about a scope in the memory hierarchy."""
|
||||
|
||||
path: str = Field(description="The scope path (e.g. /company/engineering).")
|
||||
record_count: int = Field(
|
||||
default=0,
|
||||
description="Number of records in this scope (including subscopes if applicable).",
|
||||
)
|
||||
categories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Categories used in this scope.",
|
||||
)
|
||||
oldest_record: datetime | None = Field(
|
||||
default=None,
|
||||
description="Timestamp of the oldest record in this scope.",
|
||||
)
|
||||
newest_record: datetime | None = Field(
|
||||
default=None,
|
||||
description="Timestamp of the newest record in this scope.",
|
||||
)
|
||||
child_scopes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Immediate child scope paths.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
"""Internal configuration for memory scoring, consolidation, and recall behavior.
|
||||
|
||||
Users configure these values via ``Memory(...)`` keyword arguments.
|
||||
This model is not part of the public API -- it exists so that the config
|
||||
can be passed as a single object to RecallFlow, ConsolidationFlow, and
|
||||
compute_composite_score.
|
||||
"""
|
||||
|
||||
# -- Composite score weights --
|
||||
# The recall composite score is:
|
||||
# semantic_weight * similarity + recency_weight * decay + importance_weight * importance
|
||||
# These should sum to ~1.0 for intuitive 0-1 scoring.
|
||||
|
||||
recency_weight: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"Weight for recency in the composite relevance score. "
|
||||
"Higher values favor recently created memories over older ones."
|
||||
),
|
||||
)
|
||||
semantic_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"Weight for semantic similarity in the composite relevance score. "
|
||||
"Higher values make recall rely more on vector-search closeness."
|
||||
),
|
||||
)
|
||||
importance_weight: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"Weight for explicit importance in the composite relevance score. "
|
||||
"Higher values make high-importance memories surface more often."
|
||||
),
|
||||
)
|
||||
recency_half_life_days: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
description=(
|
||||
"Number of days for the recency score to halve (exponential decay). "
|
||||
"Lower values make memories lose relevance faster; higher values "
|
||||
"keep old memories relevant longer."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Consolidation (on save) --
|
||||
|
||||
consolidation_threshold: float = Field(
|
||||
default=0.85,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"Semantic similarity above which the consolidation flow is triggered "
|
||||
"when saving new content. The LLM then decides whether to merge, "
|
||||
"update, or delete overlapping records. Set to 1.0 to disable."
|
||||
),
|
||||
)
|
||||
consolidation_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description=(
|
||||
"Maximum number of existing records to compare against when checking "
|
||||
"for consolidation during a save."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Save defaults --
|
||||
|
||||
default_importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"Importance assigned to new memories when no explicit value is given "
|
||||
"and the LLM analysis path is skipped (i.e. all fields provided by "
|
||||
"the caller)."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Recall depth control --
|
||||
# The RecallFlow router uses these thresholds to decide between returning
|
||||
# results immediately ("synthesize") and doing an extra LLM-driven
|
||||
# exploration round ("explore_deeper").
|
||||
|
||||
confidence_threshold_high: float = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"When recall confidence is at or above this value, results are "
|
||||
"returned directly without deeper exploration."
|
||||
),
|
||||
)
|
||||
confidence_threshold_low: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"When recall confidence is below this value and exploration budget "
|
||||
"remains, a deeper LLM-driven exploration round is triggered."
|
||||
),
|
||||
)
|
||||
complex_query_threshold: float = Field(
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description=(
|
||||
"For queries classified as 'complex' by the LLM, deeper exploration "
|
||||
"is triggered when confidence is below this value."
|
||||
),
|
||||
)
|
||||
exploration_budget: int = Field(
|
||||
default=1,
|
||||
ge=0,
|
||||
description=(
|
||||
"Number of LLM-driven exploration rounds allowed during deep recall. "
|
||||
"0 means recall always uses direct vector search only; higher values "
|
||||
"allow more thorough but slower retrieval."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def embed_text(embedder: Any, text: str) -> list[float]:
|
||||
"""Embed a single text string and return a list of floats.
|
||||
|
||||
Args:
|
||||
embedder: Callable that accepts a list of strings and returns embeddings.
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
List of floats representing the embedding, or empty list on failure.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
result = embedder([text])
|
||||
if not result:
|
||||
return []
|
||||
first = result[0]
|
||||
if hasattr(first, "tolist"):
|
||||
return first.tolist()
|
||||
if isinstance(first, list):
|
||||
return [float(x) for x in first]
|
||||
return list(first)
|
||||
|
||||
|
||||
def compute_composite_score(
|
||||
record: MemoryRecord,
|
||||
semantic_score: float,
|
||||
config: MemoryConfig,
|
||||
) -> tuple[float, list[str]]:
|
||||
"""Compute a weighted composite relevance score from semantic, recency, and importance.
|
||||
|
||||
composite = w_semantic * semantic + w_recency * decay + w_importance * importance
|
||||
where decay = 0.5^(age_days / half_life_days).
|
||||
|
||||
Args:
|
||||
record: The memory record (provides created_at and importance).
|
||||
semantic_score: Raw semantic similarity from vector search, in [0, 1].
|
||||
config: Weights and recency half-life.
|
||||
|
||||
Returns:
|
||||
Tuple of (composite_score, match_reasons). match_reasons includes
|
||||
"semantic" always; "recency" if decay > 0.5; "importance" if record.importance > 0.5.
|
||||
"""
|
||||
age_seconds = (datetime.utcnow() - record.created_at).total_seconds()
|
||||
age_days = max(age_seconds / 86400.0, 0.0)
|
||||
decay = 0.5 ** (age_days / config.recency_half_life_days)
|
||||
|
||||
composite = (
|
||||
config.semantic_weight * semantic_score
|
||||
+ config.recency_weight * decay
|
||||
+ config.importance_weight * record.importance
|
||||
)
|
||||
|
||||
reasons: list[str] = ["semantic"]
|
||||
if decay > 0.5:
|
||||
reasons.append("recency")
|
||||
if record.importance > 0.5:
|
||||
reasons.append("importance")
|
||||
|
||||
return composite, reasons
|
||||
543
lib/crewai/src/crewai/memory/unified_memory.py
Normal file
543
lib/crewai/src/crewai/memory/unified_memory.py
Normal file
@@ -0,0 +1,543 @@
|
||||
"""Unified Memory class: single intelligent memory with LLM analysis and pluggable storage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.analyze import (
|
||||
MemoryAnalysis,
|
||||
analyze_for_save,
|
||||
extract_memories_from_content,
|
||||
)
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.types import (
|
||||
MemoryConfig,
|
||||
MemoryMatch,
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
compute_composite_score,
|
||||
embed_text,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedder() -> Any:
|
||||
"""Build default OpenAI embedder for memory."""
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
return build_embedder({"provider": "openai", "config": {}})
|
||||
|
||||
|
||||
class Memory:
|
||||
"""Unified memory: standalone, LLM-analyzed, with intelligent recall flow.
|
||||
|
||||
Works without agent/crew. Uses LLM to infer scope, categories, importance on save.
|
||||
Uses RecallFlow for adaptive-depth recall. Supports scope/slice views and
|
||||
pluggable storage (LanceDB default).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | str = "gpt-4o-mini",
|
||||
storage: StorageBackend | str = "lancedb",
|
||||
embedder: Any = None,
|
||||
# -- Scoring weights --
|
||||
# These three weights control how recall results are ranked.
|
||||
# The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance.
|
||||
# They should sum to ~1.0 for intuitive scoring.
|
||||
recency_weight: float = 0.3,
|
||||
semantic_weight: float = 0.5,
|
||||
importance_weight: float = 0.2,
|
||||
# How quickly old memories lose relevance. The recency score halves every
|
||||
# N days (exponential decay). Lower = faster forgetting; higher = longer relevance.
|
||||
recency_half_life_days: int = 30,
|
||||
# -- Consolidation --
|
||||
# When remembering new content, if an existing record has similarity >= this
|
||||
# threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable.
|
||||
consolidation_threshold: float = 0.85,
|
||||
# Max existing records to compare against when checking for consolidation.
|
||||
consolidation_limit: int = 5,
|
||||
# -- Save defaults --
|
||||
# Importance assigned to new memories when no explicit value is given and
|
||||
# the LLM analysis path is skipped (all fields provided by the caller).
|
||||
default_importance: float = 0.5,
|
||||
# -- Recall depth control --
|
||||
# These thresholds govern the RecallFlow router that decides between
|
||||
# returning results immediately ("synthesize") vs. doing an extra
|
||||
# LLM-driven exploration round ("explore_deeper").
|
||||
# confidence >= confidence_threshold_high => always synthesize
|
||||
# confidence < confidence_threshold_low => explore deeper (if budget > 0)
|
||||
# complex query + confidence < complex_query_threshold => explore deeper
|
||||
confidence_threshold_high: float = 0.8,
|
||||
confidence_threshold_low: float = 0.5,
|
||||
complex_query_threshold: float = 0.7,
|
||||
# How many LLM-driven exploration rounds the RecallFlow is allowed to run.
|
||||
# 0 = always shallow (vector search only); higher = more thorough but slower.
|
||||
exploration_budget: int = 1,
|
||||
) -> None:
|
||||
"""Initialize Memory.
|
||||
|
||||
Args:
|
||||
llm: LLM for analysis (model name or BaseLLM instance).
|
||||
storage: Backend: "lancedb" or a StorageBackend instance.
|
||||
embedder: Embedding function; None => default OpenAI.
|
||||
recency_weight: Weight for recency in the composite relevance score.
|
||||
semantic_weight: Weight for semantic similarity in the composite relevance score.
|
||||
importance_weight: Weight for importance in the composite relevance score.
|
||||
recency_half_life_days: Recency score halves every N days (exponential decay).
|
||||
consolidation_threshold: Similarity above which consolidation is triggered on save.
|
||||
consolidation_limit: Max existing records to compare during consolidation.
|
||||
default_importance: Default importance when not provided or inferred.
|
||||
confidence_threshold_high: Recall confidence above which results are returned directly.
|
||||
confidence_threshold_low: Recall confidence below which deeper exploration is triggered.
|
||||
complex_query_threshold: For complex queries, explore deeper below this confidence.
|
||||
exploration_budget: Number of LLM-driven exploration rounds during deep recall.
|
||||
"""
|
||||
from crewai.llm import LLM
|
||||
|
||||
self._config = MemoryConfig(
|
||||
recency_weight=recency_weight,
|
||||
semantic_weight=semantic_weight,
|
||||
importance_weight=importance_weight,
|
||||
recency_half_life_days=recency_half_life_days,
|
||||
consolidation_threshold=consolidation_threshold,
|
||||
consolidation_limit=consolidation_limit,
|
||||
default_importance=default_importance,
|
||||
confidence_threshold_high=confidence_threshold_high,
|
||||
confidence_threshold_low=confidence_threshold_low,
|
||||
complex_query_threshold=complex_query_threshold,
|
||||
exploration_budget=exploration_budget,
|
||||
)
|
||||
self._llm = LLM(model=llm) if isinstance(llm, str) else llm
|
||||
if storage == "lancedb":
|
||||
self._storage = LanceDBStorage()
|
||||
elif isinstance(storage, str):
|
||||
self._storage = LanceDBStorage(path=storage)
|
||||
else:
|
||||
self._storage = storage
|
||||
self._embedder = embedder if embedder is not None else _default_embedder()
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
) -> MemoryRecord:
|
||||
"""Store content in memory. Infers scope/categories/importance via LLM when not provided.
|
||||
|
||||
Args:
|
||||
content: Text to remember.
|
||||
scope: Optional scope path; inferred if None.
|
||||
categories: Optional categories; inferred if None.
|
||||
metadata: Optional metadata; merged with LLM-extracted if scope/categories/importance inferred.
|
||||
importance: Optional importance 0-1; inferred if None.
|
||||
|
||||
Returns:
|
||||
The created MemoryRecord.
|
||||
|
||||
Raises:
|
||||
Exception: On save failure (events emitted).
|
||||
"""
|
||||
_source = "unified_memory"
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemorySaveStartedEvent(
|
||||
value=content,
|
||||
metadata=metadata,
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
start = time.perf_counter()
|
||||
|
||||
need_analysis = (
|
||||
scope is None
|
||||
or categories is None
|
||||
or importance is None
|
||||
or (metadata is None and scope is None)
|
||||
)
|
||||
if need_analysis:
|
||||
existing_scopes = self._storage.list_scopes("/")
|
||||
if not existing_scopes:
|
||||
existing_scopes = ["/"]
|
||||
existing_categories = list(
|
||||
self._storage.list_categories(scope_prefix=None).keys()
|
||||
)
|
||||
analysis: MemoryAnalysis = analyze_for_save(
|
||||
content,
|
||||
existing_scopes,
|
||||
existing_categories,
|
||||
self._llm,
|
||||
)
|
||||
scope = scope or analysis.suggested_scope or "/"
|
||||
categories = categories if categories is not None else analysis.categories
|
||||
importance = (
|
||||
importance
|
||||
if importance is not None
|
||||
else analysis.importance
|
||||
)
|
||||
metadata = dict(
|
||||
metadata or {},
|
||||
**(
|
||||
analysis.extracted_metadata.model_dump()
|
||||
if analysis.extracted_metadata
|
||||
else {}
|
||||
),
|
||||
)
|
||||
else:
|
||||
scope = scope or "/"
|
||||
categories = categories or []
|
||||
metadata = metadata or {}
|
||||
importance = importance if importance is not None else self._config.default_importance
|
||||
|
||||
embedding = embed_text(self._embedder, content)
|
||||
from crewai.memory.consolidation_flow import ConsolidationFlow
|
||||
|
||||
flow = ConsolidationFlow(
|
||||
storage=self._storage,
|
||||
llm=self._llm,
|
||||
embedder=self._embedder,
|
||||
config=self._config,
|
||||
)
|
||||
flow.kickoff(
|
||||
inputs={
|
||||
"new_content": content,
|
||||
"new_embedding": embedding,
|
||||
"scope": scope,
|
||||
"categories": categories,
|
||||
"metadata": metadata or {},
|
||||
"importance": importance,
|
||||
},
|
||||
)
|
||||
record = flow.state.result_record
|
||||
if record is None:
|
||||
record = MemoryRecord(
|
||||
content=content,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
metadata=metadata or {},
|
||||
importance=importance,
|
||||
embedding=embedding if embedding else None,
|
||||
)
|
||||
self._storage.save([record])
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemorySaveCompletedEvent(
|
||||
value=content,
|
||||
metadata=metadata or {},
|
||||
agent_role=None,
|
||||
save_time_ms=elapsed_ms,
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
return record
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemorySaveFailedEvent(
|
||||
value=content,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from a raw content blob using the LLM.
|
||||
|
||||
This is a pure helper -- it does NOT store anything.
|
||||
Call remember() on each returned string to persist them.
|
||||
|
||||
Args:
|
||||
content: Raw text (e.g. task + result dump).
|
||||
|
||||
Returns:
|
||||
List of short, self-contained memory statements.
|
||||
"""
|
||||
return extract_memories_from_content(content, self._llm)
|
||||
|
||||
def recall(
|
||||
self,
|
||||
query: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: Literal["shallow", "deep", "auto"] = "auto",
|
||||
) -> list[MemoryMatch]:
|
||||
"""Retrieve relevant memories. Shallow = direct vector search; deep/auto = RecallFlow.
|
||||
|
||||
Args:
|
||||
query: Natural language query.
|
||||
scope: Optional scope prefix to search within.
|
||||
categories: Optional category filter.
|
||||
limit: Max number of results.
|
||||
depth: "shallow" for direct search, "deep" or "auto" for intelligent flow.
|
||||
|
||||
Returns:
|
||||
List of MemoryMatch, ordered by relevance.
|
||||
"""
|
||||
_source = "unified_memory"
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=None,
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
start = time.perf_counter()
|
||||
|
||||
if depth == "shallow":
|
||||
embedding = embed_text(self._embedder, query)
|
||||
if not embedding:
|
||||
results: list[MemoryMatch] = []
|
||||
else:
|
||||
raw = self._storage.search(
|
||||
embedding,
|
||||
scope_prefix=scope,
|
||||
categories=categories,
|
||||
limit=limit,
|
||||
min_score=0.0,
|
||||
)
|
||||
results = []
|
||||
for r, s in raw:
|
||||
composite, reasons = compute_composite_score(
|
||||
r, s, self._config
|
||||
)
|
||||
results.append(
|
||||
MemoryMatch(
|
||||
record=r,
|
||||
score=composite,
|
||||
match_reasons=reasons,
|
||||
)
|
||||
)
|
||||
results.sort(key=lambda m: m.score, reverse=True)
|
||||
else:
|
||||
flow = RecallFlow(
|
||||
storage=self._storage,
|
||||
llm=self._llm,
|
||||
config=self._config,
|
||||
)
|
||||
embedding = embed_text(self._embedder, query)
|
||||
flow.kickoff(
|
||||
inputs={
|
||||
"query": query,
|
||||
"scope": scope,
|
||||
"categories": categories or [],
|
||||
"limit": limit,
|
||||
"query_embedding": embedding,
|
||||
}
|
||||
)
|
||||
results = flow.state.final_results
|
||||
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=None,
|
||||
query_time_ms=elapsed_ms,
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=None,
|
||||
error=str(e),
|
||||
source_type=_source,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def forget(
|
||||
self,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
older_than: datetime | None = None,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
record_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""Delete memories matching criteria.
|
||||
|
||||
Returns:
|
||||
Number of records deleted.
|
||||
"""
|
||||
return self._storage.delete(
|
||||
scope_prefix=scope,
|
||||
categories=categories,
|
||||
record_ids=record_ids,
|
||||
older_than=older_than,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
record_id: str,
|
||||
content: str | None = None,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
) -> MemoryRecord:
|
||||
"""Update an existing memory record by ID.
|
||||
|
||||
Args:
|
||||
record_id: ID of the record to update.
|
||||
content: New content; re-embedded if provided.
|
||||
scope: New scope path.
|
||||
categories: New categories.
|
||||
metadata: New metadata.
|
||||
importance: New importance score.
|
||||
|
||||
Returns:
|
||||
The updated MemoryRecord.
|
||||
|
||||
Raises:
|
||||
ValueError: If the storage does not support get_record or the record is not found.
|
||||
"""
|
||||
get_record = getattr(self._storage, "get_record", None)
|
||||
if get_record is None:
|
||||
raise ValueError("This storage backend does not support update by ID")
|
||||
existing = get_record(record_id)
|
||||
if existing is None:
|
||||
raise ValueError(f"Record not found: {record_id}")
|
||||
now = datetime.utcnow()
|
||||
updates: dict[str, Any] = {"last_accessed": now}
|
||||
if content is not None:
|
||||
updates["content"] = content
|
||||
embedding = embed_text(self._embedder, content)
|
||||
updates["embedding"] = embedding if embedding else existing.embedding
|
||||
if scope is not None:
|
||||
updates["scope"] = scope
|
||||
if categories is not None:
|
||||
updates["categories"] = categories
|
||||
if metadata is not None:
|
||||
updates["metadata"] = metadata
|
||||
if importance is not None:
|
||||
updates["importance"] = importance
|
||||
updated = existing.model_copy(update=updates)
|
||||
self._storage.update(updated)
|
||||
return updated
|
||||
|
||||
def scope(self, path: str) -> Any:
|
||||
"""Return a scoped view of this memory."""
|
||||
from crewai.memory.memory_scope import MemoryScope
|
||||
|
||||
return MemoryScope(memory=self, root_path=path)
|
||||
|
||||
def slice(
|
||||
self,
|
||||
scopes: list[str],
|
||||
categories: list[str] | None = None,
|
||||
read_only: bool = True,
|
||||
) -> Any:
|
||||
"""Return a multi-scope view (slice) of this memory."""
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
return MemorySlice(
|
||||
memory=self,
|
||||
scopes=scopes,
|
||||
categories=categories,
|
||||
read_only=read_only,
|
||||
)
|
||||
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List immediate child scopes under path."""
|
||||
return self._storage.list_scopes(path)
|
||||
|
||||
def list_records(
|
||||
self, scope: str | None = None, limit: int = 200
|
||||
) -> list[MemoryRecord]:
|
||||
"""List records in a scope, newest first."""
|
||||
return self._storage.list_records(scope_prefix=scope, limit=limit)
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Return scope info for path."""
|
||||
return self._storage.get_scope_info(path)
|
||||
|
||||
def tree(self, path: str = "/", max_depth: int = 3) -> str:
|
||||
"""Return a formatted tree of scopes (string)."""
|
||||
lines: list[str] = []
|
||||
|
||||
def _walk(p: str, depth: int, prefix: str) -> None:
|
||||
if depth > max_depth:
|
||||
return
|
||||
info = self._storage.get_scope_info(p)
|
||||
lines.append(f"{prefix}{p or '/'} ({info.record_count} records)")
|
||||
for child in info.child_scopes[:20]:
|
||||
_walk(child, depth + 1, prefix + " ")
|
||||
|
||||
_walk(path.rstrip("/") or "/", 0, "")
|
||||
return "\n".join(lines) if lines else f"{path or '/'} (0 records)"
|
||||
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""List categories and counts; path=None means global."""
|
||||
return self._storage.list_categories(scope_prefix=path)
|
||||
|
||||
def reset(self, scope: str | None = None) -> None:
|
||||
"""Reset (delete all) memories in scope. None = all."""
|
||||
self._storage.reset(scope_prefix=scope)
|
||||
|
||||
async def aextract_memories(self, content: str) -> list[str]:
|
||||
"""Async variant of extract_memories."""
|
||||
return self.extract_memories(content)
|
||||
|
||||
async def aremember(
|
||||
self,
|
||||
content: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
importance: float | None = None,
|
||||
) -> MemoryRecord:
|
||||
"""Async remember: delegates to sync for now."""
|
||||
return self.remember(
|
||||
content,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
metadata=metadata,
|
||||
importance=importance,
|
||||
)
|
||||
|
||||
async def arecall(
|
||||
self,
|
||||
query: str,
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: Literal["shallow", "deep", "auto"] = "auto",
|
||||
) -> list[MemoryMatch]:
|
||||
"""Async recall: delegates to sync for now."""
|
||||
return self.recall(
|
||||
query,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
limit=limit,
|
||||
depth=depth,
|
||||
)
|
||||
@@ -57,6 +57,16 @@
|
||||
"default_action": "Please provide a detailed description of this image, including all visual elements, context, and any notable details you can observe."
|
||||
}
|
||||
},
|
||||
"memory": {
|
||||
"save_system": "You analyze content to be stored in a hierarchical memory system.\nGiven the content and the existing scopes and categories, output:\n1. suggested_scope: The best matching existing scope path, or a new path if none fit (use / for root).\n2. categories: A list of categories (reuse existing when relevant, add new ones if needed).\n3. importance: A number from 0.0 to 1.0 indicating how significant this memory is.\n4. extracted_metadata: A JSON object with any entities, dates, or topics you can extract.",
|
||||
"query_system": "You analyze a user's query for searching memory.\nGiven the query and available scopes, output:\n1. keywords: Key entities or keywords to use for filtering.\n2. time_hints: Any time or recency references in the query.\n3. suggested_scopes: Which of the available scopes are most relevant (can be empty for \"all\").\n4. complexity: 'simple' if the query is a straightforward lookup, 'complex' if it requires aggregation or multi-step reasoning.",
|
||||
"extract_memories_system": "You extract discrete, reusable memory statements from raw content (e.g. a task description and its result).\n\nFor the given content, output a list of memory statements. Each memory must:\n- Be one clear sentence or short statement\n- Be understandable without the original context\n- Capture a decision, fact, outcome, preference, lesson, or observation worth remembering\n- NOT be a vague summary or a restatement of the task description\n- NOT duplicate the same idea in different words\n\nIf there is nothing worth remembering (e.g. empty result, no decisions or facts), return an empty list.\nOutput a JSON object with a single key \"memories\" whose value is a list of strings.",
|
||||
"consolidation_system": "You are comparing new content against existing memories to decide how to consolidate them.\n\nFor each existing memory, choose one action:\n- 'keep': The existing memory is still accurate and not redundant with the new content.\n- 'update': The existing memory should be updated with new information. Provide the updated content.\n- 'delete': The existing memory is outdated, superseded, or contradicted by the new content.\n\nAlso decide whether the new content should be inserted as a separate memory:\n- insert_new=true: The new content adds information not fully captured by existing memories (even after updates).\n- insert_new=false: The new content is fully captured by the existing memories (after any updates).\n\nBe conservative: prefer 'keep' when unsure. Only 'update' or 'delete' when there is a clear contradiction, supersession, or redundancy.",
|
||||
"extract_memories_user": "Content:\n{content}\n\nExtract memory statements as described. Return structured output.",
|
||||
"save_user": "Content to store:\n{content}\n\nExisting scopes: {existing_scopes}\nExisting categories: {existing_categories}\n\nReturn the analysis as structured output.",
|
||||
"query_user": "Query: {query}\n\nAvailable scopes: {available_scopes}\n{scope_desc}\n\nReturn the analysis as structured output.",
|
||||
"consolidation_user": "New content to consider storing:\n{new_content}\n\nExisting similar memories:\n{records_summary}\n\nReturn the consolidation plan as structured output."
|
||||
},
|
||||
"reasoning": {
|
||||
"initial_plan": "You are {role}, a professional with the following background: {backstory}\n\nYour primary goal is: {goal}\n\nAs {role}, you are creating a strategic plan for a task that requires your expertise and unique perspective.",
|
||||
"refine_plan": "You are {role}, a professional with the following background: {backstory}\n\nYour primary goal is: {goal}\n\nAs {role}, you are refining a strategic plan for a task that requires your expertise and unique perspective.",
|
||||
|
||||
@@ -86,10 +86,21 @@ class I18N(BaseModel):
|
||||
"""
|
||||
return self.retrieve("tools", tool)
|
||||
|
||||
def memory(self, key: str) -> str:
|
||||
"""Retrieve a memory prompt by key.
|
||||
|
||||
Args:
|
||||
key: The key of the memory prompt to retrieve.
|
||||
|
||||
Returns:
|
||||
The memory prompt as a string.
|
||||
"""
|
||||
return self.retrieve("memory", key)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
kind: Literal[
|
||||
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent"
|
||||
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
|
||||
],
|
||||
key: str,
|
||||
) -> str:
|
||||
|
||||
@@ -372,10 +372,7 @@ class TestFlowInvoke:
|
||||
task.human_input = False
|
||||
|
||||
crew = Mock()
|
||||
crew._short_term_memory = None
|
||||
crew._long_term_memory = None
|
||||
crew._entity_memory = None
|
||||
crew._external_memory = None
|
||||
crew._memory = None
|
||||
|
||||
agent = Mock()
|
||||
agent.role = "Test"
|
||||
@@ -398,14 +395,10 @@ class TestFlowInvoke:
|
||||
}
|
||||
|
||||
@patch.object(AgentExecutor, "kickoff")
|
||||
@patch.object(AgentExecutor, "_create_short_term_memory")
|
||||
@patch.object(AgentExecutor, "_create_long_term_memory")
|
||||
@patch.object(AgentExecutor, "_create_external_memory")
|
||||
@patch.object(AgentExecutor, "_save_to_memory")
|
||||
def test_invoke_success(
|
||||
self,
|
||||
mock_external_memory,
|
||||
mock_long_term_memory,
|
||||
mock_short_term_memory,
|
||||
mock_save_to_memory,
|
||||
mock_kickoff,
|
||||
mock_dependencies,
|
||||
):
|
||||
@@ -425,9 +418,7 @@ class TestFlowInvoke:
|
||||
|
||||
assert result == {"output": "Final result"}
|
||||
mock_kickoff.assert_called_once()
|
||||
mock_short_term_memory.assert_called_once()
|
||||
mock_long_term_memory.assert_called_once()
|
||||
mock_external_memory.assert_called_once()
|
||||
mock_save_to_memory.assert_called_once()
|
||||
|
||||
@patch.object(AgentExecutor, "kickoff")
|
||||
def test_invoke_failure_no_agent_finish(self, mock_kickoff, mock_dependencies):
|
||||
@@ -443,14 +434,10 @@ class TestFlowInvoke:
|
||||
executor.invoke(inputs)
|
||||
|
||||
@patch.object(AgentExecutor, "kickoff")
|
||||
@patch.object(AgentExecutor, "_create_short_term_memory")
|
||||
@patch.object(AgentExecutor, "_create_long_term_memory")
|
||||
@patch.object(AgentExecutor, "_create_external_memory")
|
||||
@patch.object(AgentExecutor, "_save_to_memory")
|
||||
def test_invoke_with_system_prompt(
|
||||
self,
|
||||
mock_external_memory,
|
||||
mock_long_term_memory,
|
||||
mock_short_term_memory,
|
||||
mock_save_to_memory,
|
||||
mock_kickoff,
|
||||
mock_dependencies,
|
||||
):
|
||||
@@ -470,9 +457,7 @@ class TestFlowInvoke:
|
||||
|
||||
inputs = {"input": "test", "tool_names": "", "tools": ""}
|
||||
result = executor.invoke(inputs)
|
||||
mock_short_term_memory.assert_called_once()
|
||||
mock_long_term_memory.assert_called_once()
|
||||
mock_external_memory.assert_called_once()
|
||||
mock_save_to_memory.assert_called_once()
|
||||
mock_kickoff.assert_called_once()
|
||||
|
||||
assert result == {"output": "Done"}
|
||||
|
||||
@@ -95,16 +95,14 @@ class TestAsyncAgentExecutor:
|
||||
),
|
||||
):
|
||||
with patch.object(executor, "_show_start_logs"):
|
||||
with patch.object(executor, "_create_short_term_memory"):
|
||||
with patch.object(executor, "_create_long_term_memory"):
|
||||
with patch.object(executor, "_create_external_memory"):
|
||||
result = await executor.ainvoke(
|
||||
{
|
||||
"input": "test input",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
}
|
||||
)
|
||||
with patch.object(executor, "_save_to_memory"):
|
||||
result = await executor.ainvoke(
|
||||
{
|
||||
"input": "test input",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
}
|
||||
)
|
||||
|
||||
assert result == {"output": expected_output}
|
||||
|
||||
@@ -273,16 +271,14 @@ class TestAsyncAgentExecutor:
|
||||
):
|
||||
with patch.object(executor, "_show_start_logs"):
|
||||
with patch.object(executor, "_show_logs"):
|
||||
with patch.object(executor, "_create_short_term_memory"):
|
||||
with patch.object(executor, "_create_long_term_memory"):
|
||||
with patch.object(executor, "_create_external_memory"):
|
||||
return await executor.ainvoke(
|
||||
{
|
||||
"input": f"test {executor_id}",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
}
|
||||
)
|
||||
with patch.object(executor, "_save_to_memory"):
|
||||
return await executor.ainvoke(
|
||||
{
|
||||
"input": f"test {executor_id}",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
}
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
create_and_run_executor(1),
|
||||
|
||||
@@ -16,6 +16,7 @@ import pytest
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
# A simple test tool
|
||||
@@ -1064,3 +1065,97 @@ def test_lite_agent_verbose_false_suppresses_printer_output():
|
||||
agent2.kickoff("Say hello")
|
||||
|
||||
mock_printer.print.assert_not_called()
|
||||
|
||||
|
||||
# --- LiteAgent memory integration ---
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:LiteAgent is deprecated")
|
||||
def test_lite_agent_memory_none_default():
|
||||
"""With memory=None (default), _memory is None and no memory is used."""
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Final Answer: Ok"
|
||||
mock_llm.stop = []
|
||||
mock_llm.get_token_usage_summary.return_value = UsageMetrics(
|
||||
total_tokens=10,
|
||||
prompt_tokens=5,
|
||||
completion_tokens=5,
|
||||
cached_prompt_tokens=0,
|
||||
successful_requests=1,
|
||||
)
|
||||
agent = LiteAgent(
|
||||
role="Test",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=mock_llm,
|
||||
memory=None,
|
||||
verbose=False,
|
||||
)
|
||||
assert agent._memory is None
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:LiteAgent is deprecated")
|
||||
def test_lite_agent_memory_true_resolves_to_default_memory():
|
||||
"""With memory=True, _memory is a Memory instance."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Final Answer: Ok"
|
||||
mock_llm.stop = []
|
||||
mock_llm.get_token_usage_summary.return_value = UsageMetrics(
|
||||
total_tokens=10,
|
||||
prompt_tokens=5,
|
||||
completion_tokens=5,
|
||||
cached_prompt_tokens=0,
|
||||
successful_requests=1,
|
||||
)
|
||||
agent = LiteAgent(
|
||||
role="Test",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=mock_llm,
|
||||
memory=True,
|
||||
verbose=False,
|
||||
)
|
||||
assert agent._memory is not None
|
||||
assert isinstance(agent._memory, Memory)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:LiteAgent is deprecated")
|
||||
def test_lite_agent_memory_instance_recall_and_save_called():
|
||||
"""With a custom memory instance, kickoff calls recall and then extract_memories/remember."""
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Final Answer: The answer is 42."
|
||||
mock_llm.stop = []
|
||||
mock_llm.supports_stop_words.return_value = False
|
||||
mock_llm.get_token_usage_summary.return_value = UsageMetrics(
|
||||
total_tokens=10,
|
||||
prompt_tokens=5,
|
||||
completion_tokens=5,
|
||||
cached_prompt_tokens=0,
|
||||
successful_requests=1,
|
||||
)
|
||||
mock_memory = Mock()
|
||||
mock_memory.recall.return_value = []
|
||||
mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."]
|
||||
|
||||
agent = LiteAgent(
|
||||
role="Test",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=mock_llm,
|
||||
memory=mock_memory,
|
||||
verbose=False,
|
||||
)
|
||||
assert agent._memory is mock_memory
|
||||
|
||||
agent.kickoff("What is the answer?")
|
||||
|
||||
mock_memory.recall.assert_called_once()
|
||||
call_kw = mock_memory.recall.call_args[1]
|
||||
assert call_kw.get("limit") == 10
|
||||
assert call_kw.get("depth") == "shallow"
|
||||
mock_memory.extract_memories.assert_called_once()
|
||||
assert mock_memory.remember.call_count == 2
|
||||
mock_memory.remember.assert_any_call("Fact one.")
|
||||
mock_memory.remember.assert_any_call("Fact two.")
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,191 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: !!binary |
|
||||
Ct8MCiQKIgoMc2VydmljZS5uYW1lEhIKEGNyZXdBSS10ZWxlbWV0cnkStgwKEgoQY3Jld2FpLnRl
|
||||
bGVtZXRyeRKcCAoQjin/Su47zAwLq3Hv6yv8GhIImRMfAPs+FOMqDENyZXcgQ3JlYXRlZDABOYCY
|
||||
xbgUrDUYQVie07gUrDUYShsKDmNyZXdhaV92ZXJzaW9uEgkKBzAuMTE0LjBKGgoOcHl0aG9uX3Zl
|
||||
cnNpb24SCAoGMy4xMi45Si4KCGNyZXdfa2V5EiIKIDA3YTcxNzY4Y2M0YzkzZWFiM2IzMWUzYzhk
|
||||
MjgzMmM2SjEKB2NyZXdfaWQSJgokY2UyMGFlNWYtZmMyNy00YWJhLWExYWMtNzUwY2ZhZmMwMTE4
|
||||
ShwKDGNyZXdfcHJvY2VzcxIMCgpzZXF1ZW50aWFsShEKC2NyZXdfbWVtb3J5EgIQAEoaChRjcmV3
|
||||
X251bWJlcl9vZl90YXNrcxICGAFKGwoVY3Jld19udW1iZXJfb2ZfYWdlbnRzEgIYAUo6ChBjcmV3
|
||||
X2ZpbmdlcnByaW50EiYKJDQ4NGFmZDhjLTczMmEtNGM1Ni1hZjk2LTU2MzkwMjNmYjhjOUo7Chtj
|
||||
cmV3X2ZpbmdlcnByaW50X2NyZWF0ZWRfYXQSHAoaMjAyNS0wNC0xMlQxNzoyNzoxNS42NzMyMjNK
|
||||
0AIKC2NyZXdfYWdlbnRzEsACCr0CW3sia2V5IjogIjAyZGYxM2UzNjcxMmFiZjUxZDIzOGZlZWJh
|
||||
YjFjYTI2IiwgImlkIjogImYyYjZkYTU1LTNiMGItNDZiNy05Mzk5LWE5NDJmYjQ4YzU2OSIsICJy
|
||||
b2xlIjogIlJlc2VhcmNoZXIiLCAidmVyYm9zZT8iOiB0cnVlLCAibWF4X2l0ZXIiOiAyNSwgIm1h
|
||||
eF9ycG0iOiBudWxsLCAiZnVuY3Rpb25fY2FsbGluZ19sbG0iOiAiIiwgImxsbSI6ICJncHQtNG8t
|
||||
bWluaSIsICJkZWxlZ2F0aW9uX2VuYWJsZWQ/IjogZmFsc2UsICJhbGxvd19jb2RlX2V4ZWN1dGlv
|
||||
bj8iOiBmYWxzZSwgIm1heF9yZXRyeV9saW1pdCI6IDIsICJ0b29sc19uYW1lcyI6IFtdfV1K/wEK
|
||||
CmNyZXdfdGFza3MS8AEK7QFbeyJrZXkiOiAiN2I0MmRmM2MzYzc0YzIxYzg5NDgwZTBjMDcwNTM4
|
||||
NWYiLCAiaWQiOiAiYmE1MjFjNDgtYzcwNS00MDRlLWE5MDktMjkwZGM0NTlkOThkIiwgImFzeW5j
|
||||
X2V4ZWN1dGlvbj8iOiBmYWxzZSwgImh1bWFuX2lucHV0PyI6IGZhbHNlLCAiYWdlbnRfcm9sZSI6
|
||||
ICJSZXNlYXJjaGVyIiwgImFnZW50X2tleSI6ICIwMmRmMTNlMzY3MTJhYmY1MWQyMzhmZWViYWIx
|
||||
Y2EyNiIsICJ0b29sc19uYW1lcyI6IFtdfV16AhgBhQEAAQAAEoAEChAmCOpHN6fX3l0shQvTLjrB
|
||||
EgjLTyt4A1p7wyoMVGFzayBDcmVhdGVkMAE5gN7juBSsNRhBmFfkuBSsNRhKLgoIY3Jld19rZXkS
|
||||
IgogMDdhNzE3NjhjYzRjOTNlYWIzYjMxZTNjOGQyODMyYzZKMQoHY3Jld19pZBImCiRjZTIwYWU1
|
||||
Zi1mYzI3LTRhYmEtYTFhYy03NTBjZmFmYzAxMThKLgoIdGFza19rZXkSIgogN2I0MmRmM2MzYzc0
|
||||
YzIxYzg5NDgwZTBjMDcwNTM4NWZKMQoHdGFza19pZBImCiRiYTUyMWM0OC1jNzA1LTQwNGUtYTkw
|
||||
OS0yOTBkYzQ1OWQ5OGRKOgoQY3Jld19maW5nZXJwcmludBImCiQ0ODRhZmQ4Yy03MzJhLTRjNTYt
|
||||
YWY5Ni01NjM5MDIzZmI4YzlKOgoQdGFza19maW5nZXJwcmludBImCiRhMDcyNjgwNC05ZjIwLTQw
|
||||
ODgtYWFmOC1iNzhkYTUyNmM3NjlKOwobdGFza19maW5nZXJwcmludF9jcmVhdGVkX2F0EhwKGjIw
|
||||
MjUtMDQtMTJUMTc6Mjc6MTUuNjczMTgxSjsKEWFnZW50X2ZpbmdlcnByaW50EiYKJDNiZDE2MmNm
|
||||
LWNmMWQtNGUwZi04ZmIzLTk3MDljMDkyNmM4ZHoCGAGFAQABAAA=
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '1634'
|
||||
Content-Type:
|
||||
- application/x-protobuf
|
||||
User-Agent:
|
||||
- OTel-OTLP-Exporter-Python/1.31.1
|
||||
method: POST
|
||||
uri: https://telemetry.crewai.com:4319/v1/traces
|
||||
response:
|
||||
body:
|
||||
string: "\n\0"
|
||||
headers:
|
||||
Content-Length:
|
||||
- '2'
|
||||
Content-Type:
|
||||
- application/x-protobuf
|
||||
Date:
|
||||
- Sat, 12 Apr 2025 20:27:16 GMT
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Researcher. You are
|
||||
a researcher at a leading tech think tank.\nYour personal goal is: Search relevant
|
||||
data and provide results\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Perform a search
|
||||
on specific topics.\n\nThis is the expected criteria for your final answer:
|
||||
A list of relevant URLs based on the search query.\nyou MUST return the actual
|
||||
complete content as the final answer, not a summary.\n\n# Useful context: \nExternal
|
||||
memories:\n\n\nBegin! This is VERY important to you, use the tools available
|
||||
and give your best Final Answer, your job depends on it!\n\nThought:"}], "model":
|
||||
"gpt-4o-mini", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '989'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=nSje5Zn_Lk69BDG85XIauC2hrZjGl0pR2sel9__KWGw-1744489610-1.0.1.1-CPlAgcgTAE30uWrbi_2wiCWrbRDRWiaa.YuQMgST42DLDVg_wdNlJMDQT3Lsqk.g.BO68A66TTirWA0blQaQw.9xdBbPwKO609_ftjdwi5U;
|
||||
_cfuvid=XLC52GLAWCOeWn2vI379CnSGKjPa7f.qr2vSAQ_R66M-1744489610542-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.68.2
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.68.2
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-BLbjXyMvmR8ctf0sqhp7F1ePskveM\",\n \"object\"\
|
||||
: \"chat.completion\",\n \"created\": 1744489635,\n \"model\": \"gpt-4o-mini-2024-07-18\"\
|
||||
,\n \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \
|
||||
\ \"role\": \"assistant\",\n \"content\": \"I now can give a great\
|
||||
\ answer \\nFinal Answer: Here is a list of relevant URLs based on the search\
|
||||
\ query:\\n\\n1. **Artificial Intelligence in Healthcare**\\n - https://www.healthit.gov/topic/scientific-initiatives/ai-healthcare\\\
|
||||
n - https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7317789/\\n - https://www.forbes.com/sites/bernardmarr/2021/10/18/the-top-5-ways-ai-is-being-used-in-healthcare/?sh=3edf5df51c9c\\\
|
||||
n\\n2. **Blockchain Technology in Supply Chain Management**\\n - https://www.ibm.com/blockchain/supply-chain\\\
|
||||
n - https://www.gartner.com/en/newsroom/press-releases/2021-06-23-gartner-says-three-use-cases-for-blockchain-in-supply-chain-are-scaling\\\
|
||||
n - https://www2.deloitte.com/us/en/insights/industry/retail-distribution/blockchain-in-supply-chain.html\\\
|
||||
n\\n3. **Renewable Energy Innovations**\\n - https://www.irena.org/publications/2020/Sep/Renewable-Power-Generation-Costs-in-2020\\\
|
||||
n - https://www.nrel.gov/docs/fy20osti/77021.pdf\\n - https://www.cnbc.com/2021/11/03/renewable-energy-could-get-its-first-taste-of-markets-in-2021.html\\\
|
||||
n\\n4. **7G Technology Developments**\\n - https://www.sciencedirect.com/science/article/pii/S1389128619308189\\\
|
||||
n - https://www.forbes.com/sites/bernardmarr/2021/11/01/what-is-7g-technology-a-beginners-guide-to-the-future-of-mobile-communications/?sh=51b8a7e1464a\\\
|
||||
n - https://www.ericsson.com/en/reports-and-research/reports/7g-networks-a-powerful-future-for-connected-society\\\
|
||||
n\\n5. **Impact of Quantum Computing on Cybersecurity**\\n - https://www.ibm.com/blogs/research/2021/09/quantum-computing-cybersecurity/\\\
|
||||
n - https://www.sciencedirect.com/science/article/pii/S0167739X21000072\\\
|
||||
n - https://www.techrepublic.com/article/how-quantum-computing-will-change-cybersecurity/\\\
|
||||
n\\nThese URLs should provide comprehensive information on the topics searched,\
|
||||
\ providing valuable insights and data for your research needs.\",\n \
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\"\
|
||||
: null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n\
|
||||
\ \"prompt_tokens\": 185,\n \"completion_tokens\": 534,\n \"total_tokens\"\
|
||||
: 719,\n \"prompt_tokens_details\": {\n \"cached_tokens\": 0,\n \
|
||||
\ \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n \
|
||||
\ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\"\
|
||||
: 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\"\
|
||||
: \"default\",\n \"system_fingerprint\": \"fp_80cf447eee\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 92f576a01d3b7e05-GRU
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Sat, 12 Apr 2025 20:27:24 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '8805'
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999788'
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_7c2d313d0b5997e903553a782b2afa25
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -85,39 +85,41 @@ def test_reset_all_memories(mock_get_crews, runner):
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
def test_reset_short_term_memories(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s"])
|
||||
def test_reset_memory(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Short term memory has been reset." in result.output
|
||||
f"[Crew ({crew.name})] Memory has been reset." in result.output
|
||||
)
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
def test_reset_entity_memories(mock_get_crews, runner):
|
||||
def test_reset_short_flag_deprecated_maps_to_memory(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s"])
|
||||
assert "deprecated" in result.output.lower()
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_entity_flag_deprecated_maps_to_memory(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-e"])
|
||||
call_count = 0
|
||||
assert "deprecated" in result.output.lower()
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert f"[Crew ({crew.name})] Entity memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_long_term_memories(mock_get_crews, runner):
|
||||
def test_reset_long_flag_deprecated_maps_to_memory(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-l"])
|
||||
call_count = 0
|
||||
assert "deprecated" in result.output.lower()
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert f"[Crew ({crew.name})] Long term memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_kickoff_outputs(mock_get_crews, runner):
|
||||
@@ -134,17 +136,14 @@ def test_reset_kickoff_outputs(mock_get_crews, runner):
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
|
||||
def test_reset_multiple_memory_flags(mock_get_crews, runner):
|
||||
def test_reset_multiple_legacy_flags_collapsed_to_single_memory_reset(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||
# Both legacy flags collapse to a single --memory reset
|
||||
assert "deprecated" in result.output.lower()
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
assert (
|
||||
f"[Crew ({crew.name})] Long term memory has been reset.\n"
|
||||
f"[Crew ({crew.name})] Short term memory has been reset.\n" in result.output
|
||||
)
|
||||
crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({crew.name})] Memory has been reset." in result.output
|
||||
call_count += 1
|
||||
|
||||
assert call_count == 1, "reset_memories should have been called once"
|
||||
|
||||
@@ -1,496 +0,0 @@
|
||||
"""Tests for async memory operations."""
|
||||
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Fixture to create a mock agent."""
|
||||
return Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task(mock_agent):
|
||||
"""Fixture to create a mock task."""
|
||||
return Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=mock_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory(mock_agent, mock_task):
|
||||
"""Fixture to create a ShortTermMemory instance."""
|
||||
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(tmp_path):
|
||||
"""Fixture to create a LongTermMemory instance."""
|
||||
db_path = str(tmp_path / "test_ltm.db")
|
||||
return LongTermMemory(path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||
"""Fixture to create an EntityMemory instance."""
|
||||
return EntityMemory(
|
||||
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||
path=str(tmp_path / "test_entities"),
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncShortTermMemory:
|
||||
"""Tests for async ShortTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, short_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
await short_term_memory.asave(
|
||||
value="async test value",
|
||||
metadata={"task": "async_test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, short_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await short_term_memory.asearch(
|
||||
query="async test query",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
|
||||
class TestAsyncLongTermMemory:
|
||||
"""Tests for async LongTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, long_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = LongTermMemoryItem(
|
||||
task="async test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "async test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, long_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||
"""Test that asave followed by asearch works correctly."""
|
||||
item = LongTermMemoryItem(
|
||||
task="integration test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "integration test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||
|
||||
assert results is not None
|
||||
assert len(results) == 1
|
||||
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||
|
||||
|
||||
class TestAsyncEntityMemory:
|
||||
"""Tests for async EntityMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||
"""Test that asave with a single item emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = EntityMemoryItem(
|
||||
name="TestEntity",
|
||||
type="Person",
|
||||
description="A test entity for async operations",
|
||||
relationships="Related to other test entities",
|
||||
)
|
||||
|
||||
await entity_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, entity_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
|
||||
class TestAsyncContextualMemory:
|
||||
"""Tests for async ContextualMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
||||
"""Test that abuild_context_for_task returns empty string for empty query."""
|
||||
mock_task.description = ""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
||||
"""Test that abuild_context_for_task handles None memory sources."""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
||||
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
||||
)
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "STM insight" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "LTM suggestion" in result
|
||||
assert "Entities:" in result
|
||||
assert "Entity info" in result
|
||||
assert "External memories:" in result
|
||||
assert "External memory" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_stm_context returns properly formatted results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "First insight"},
|
||||
{"content": "Second insight"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_stm_context("test query")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "- First insight" in result
|
||||
assert "- Second insight" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_ltm_context returns properly formatted results."""
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=mock_ltm,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_ltm_context("test task")
|
||||
|
||||
assert "Historical Data:" in result
|
||||
assert "- Suggestion 1" in result
|
||||
assert "- Suggestion 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_entity_context returns properly formatted results."""
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "Entity A details"},
|
||||
{"content": "Entity B details"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=mock_em,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_entity_context("test query")
|
||||
|
||||
assert "Entities:" in result
|
||||
assert "- Entity A details" in result
|
||||
assert "- Entity B details" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_external_context_returns_formatted_results(self):
|
||||
"""Test that _afetch_external_context returns properly formatted results."""
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "External data 1"},
|
||||
{"content": "External data 2"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_external_context("test query")
|
||||
|
||||
assert "External memories:" in result
|
||||
assert "- External data 1" in result
|
||||
assert "- External data 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_methods_return_empty_for_empty_results(self):
|
||||
"""Test that async fetch methods return empty string for no results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
stm_result = await contextual_memory._afetch_stm_context("query")
|
||||
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
||||
em_result = await contextual_memory._afetch_entity_context("query")
|
||||
exm_result = await contextual_memory._afetch_external_context("query")
|
||||
|
||||
assert stm_result == ""
|
||||
assert ltm_result is None
|
||||
assert em_result == ""
|
||||
assert exm_result == ""
|
||||
@@ -1,422 +0,0 @@
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew, Process
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_event_handlers():
|
||||
"""Cleanup event handlers before and after each test"""
|
||||
# Cleanup before test
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup after test
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
return mock_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_configure_mem0(mock_mem0_memory):
|
||||
with patch(
|
||||
"crewai.memory.external.external_memory.ExternalMemory._configure_mem0",
|
||||
return_value=mock_mem0_memory,
|
||||
) as mocked:
|
||||
yield mocked
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def external_memory_with_mocked_config(patch_configure_mem0):
|
||||
embedder_config = {"provider": "mem0"}
|
||||
external_memory = ExternalMemory(embedder_config=embedder_config)
|
||||
return external_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def crew_with_external_memory(external_memory_with_mocked_config, patch_configure_mem0):
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
verbose=True,
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
external_memory=external_memory_with_mocked_config,
|
||||
)
|
||||
|
||||
return crew
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def crew_with_external_memory_without_memory_flag(
|
||||
external_memory_with_mocked_config, patch_configure_mem0
|
||||
):
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
verbose=True,
|
||||
process=Process.sequential,
|
||||
external_memory=external_memory_with_mocked_config,
|
||||
)
|
||||
|
||||
return crew
|
||||
|
||||
|
||||
def test_external_memory_initialization(external_memory_with_mocked_config):
|
||||
assert external_memory_with_mocked_config is not None
|
||||
assert isinstance(external_memory_with_mocked_config, ExternalMemory)
|
||||
|
||||
|
||||
def test_external_memory_save(external_memory_with_mocked_config):
|
||||
memory_item = ExternalMemoryItem(
|
||||
value="test value", metadata={"task": "test_task"}, agent="test_agent"
|
||||
)
|
||||
|
||||
with patch.object(ExternalMemory, "save") as mock_save:
|
||||
external_memory_with_mocked_config.save(
|
||||
value=memory_item.value,
|
||||
metadata=memory_item.metadata,
|
||||
agent=memory_item.agent,
|
||||
)
|
||||
|
||||
mock_save.assert_called_once_with(
|
||||
value=memory_item.value,
|
||||
metadata=memory_item.metadata,
|
||||
agent=memory_item.agent,
|
||||
)
|
||||
|
||||
|
||||
def test_external_memory_reset(external_memory_with_mocked_config):
|
||||
with patch(
|
||||
"crewai.memory.external.external_memory.ExternalMemory.reset"
|
||||
) as mock_reset:
|
||||
external_memory_with_mocked_config.reset()
|
||||
mock_reset.assert_called_once()
|
||||
|
||||
|
||||
def test_external_memory_supported_storages():
|
||||
supported_storages = ExternalMemory.external_supported_storages()
|
||||
assert "mem0" in supported_storages
|
||||
assert callable(supported_storages["mem0"])
|
||||
|
||||
|
||||
def test_external_memory_create_storage_invalid_provider():
|
||||
embedder_config = {"provider": "invalid_provider", "config": {}}
|
||||
|
||||
with pytest.raises(ValueError, match="Provider invalid_provider not supported"):
|
||||
ExternalMemory.create_storage(None, embedder_config)
|
||||
|
||||
|
||||
def test_external_memory_create_storage_missing_provider():
|
||||
embedder_config = {"config": {}}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="embedder_config must include a 'provider' key"
|
||||
):
|
||||
ExternalMemory.create_storage(None, embedder_config)
|
||||
|
||||
|
||||
def test_external_memory_create_storage_missing_config():
|
||||
with pytest.raises(ValueError, match="embedder_config is required"):
|
||||
ExternalMemory.create_storage(None, None)
|
||||
|
||||
|
||||
def test_crew_with_external_memory_initialization(crew_with_external_memory):
|
||||
assert crew_with_external_memory._external_memory is not None
|
||||
assert isinstance(crew_with_external_memory._external_memory, ExternalMemory)
|
||||
assert crew_with_external_memory._external_memory.crew == crew_with_external_memory
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mem_type", ["external", "all"])
|
||||
def test_crew_external_memory_reset(mem_type, crew_with_external_memory):
|
||||
with patch(
|
||||
"crewai.memory.external.external_memory.ExternalMemory.reset"
|
||||
) as mock_reset:
|
||||
crew_with_external_memory.reset_memories(mem_type)
|
||||
mock_reset.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mem_method", ["search", "save"])
|
||||
@pytest.mark.vcr()
|
||||
def test_crew_external_memory_save_with_memory_flag(
|
||||
mem_method, crew_with_external_memory
|
||||
):
|
||||
with patch(
|
||||
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}"
|
||||
) as mock_method:
|
||||
crew_with_external_memory.kickoff()
|
||||
assert mock_method.call_count > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mem_method", ["search", "save"])
|
||||
@pytest.mark.vcr()
|
||||
def test_crew_external_memory_save_using_crew_without_memory_flag(
|
||||
mem_method, crew_with_external_memory_without_memory_flag
|
||||
):
|
||||
with patch(
|
||||
f"crewai.memory.external.external_memory.ExternalMemory.{mem_method}"
|
||||
) as mock_method:
|
||||
crew_with_external_memory_without_memory_flag.kickoff()
|
||||
assert mock_method.call_count > 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_storage():
|
||||
class CustomStorage(Storage):
|
||||
def __init__(self):
|
||||
self.memories = []
|
||||
|
||||
def save(self, value, metadata=None, agent=None):
|
||||
self.memories.append({"value": value, "metadata": metadata, "agent": agent})
|
||||
|
||||
def search(self, query, limit=10, score_threshold=0.5):
|
||||
return self.memories
|
||||
|
||||
def reset(self):
|
||||
self.memories = []
|
||||
|
||||
custom_storage = CustomStorage()
|
||||
return custom_storage
|
||||
|
||||
|
||||
def test_external_memory_custom_storage(custom_storage, crew_with_external_memory):
|
||||
external_memory = ExternalMemory(storage=custom_storage)
|
||||
|
||||
# by ensuring the crew is set, we can test that the storage is used
|
||||
external_memory.set_crew(crew_with_external_memory)
|
||||
|
||||
test_value = "test value"
|
||||
test_metadata = {"source": "test"}
|
||||
external_memory.save(value=test_value, metadata=test_metadata)
|
||||
|
||||
results = external_memory.search("test")
|
||||
assert len(results) == 1
|
||||
assert results[0]["value"] == test_value
|
||||
assert results[0]["metadata"] == test_metadata
|
||||
|
||||
external_memory.reset()
|
||||
results = external_memory.search("test")
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_external_memory_search_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
with condition:
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
with condition:
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
external_memory_with_mocked_config.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
and len(events["MemoryQueryCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
assert success, "Timeout waiting for search events"
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
|
||||
assert dict(events["MemoryQueryStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test value",
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
}
|
||||
|
||||
assert dict(events["MemoryQueryCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": ANY,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test value",
|
||||
"results": [],
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
"query_time_ms": ANY,
|
||||
}
|
||||
|
||||
|
||||
def test_external_memory_save_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
external_memory_with_mocked_config.save(
|
||||
value="saving value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
assert success, "Timeout waiting for save events"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
|
||||
assert dict(events["MemorySaveStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "saving value",
|
||||
"metadata": {"task": "test_task"},
|
||||
}
|
||||
|
||||
assert dict(events["MemorySaveCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "external_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": ANY,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "saving value",
|
||||
"metadata": {"task": "test_task"},
|
||||
"save_time_ms": ANY,
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory():
|
||||
"""Fixture to create a LongTermMemory instance"""
|
||||
return LongTermMemory()
|
||||
|
||||
|
||||
def test_long_term_memory_save_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for save events"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
|
||||
assert dict(events["MemorySaveStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": "test_agent",
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "test_task",
|
||||
"metadata": {"task": "test_task", "quality": 0.5},
|
||||
}
|
||||
assert dict(events["MemorySaveCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": "test_agent",
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "test_task",
|
||||
"metadata": {
|
||||
"task": "test_task",
|
||||
"quality": 0.5,
|
||||
"agent": "test_agent",
|
||||
"expected_output": "test_output",
|
||||
},
|
||||
"save_time_ms": ANY,
|
||||
}
|
||||
|
||||
|
||||
def test_long_term_memory_search_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
with condition:
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
with condition:
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
test_query = "test query"
|
||||
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
and len(events["MemoryQueryCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for search events"
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
assert len(events["MemoryQueryFailedEvent"]) == 0
|
||||
|
||||
assert dict(events["MemoryQueryStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test query",
|
||||
"limit": 5,
|
||||
"score_threshold": None,
|
||||
}
|
||||
|
||||
assert dict(events["MemoryQueryCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "long_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": ANY,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test query",
|
||||
"results": None,
|
||||
"limit": 5,
|
||||
"score_threshold": None,
|
||||
"query_time_ms": ANY,
|
||||
}
|
||||
|
||||
|
||||
def test_save_and_search(long_term_memory):
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task", latest_n=5)[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
assert find["metadata"]["quality"] == 0.5
|
||||
assert find["metadata"]["task"] == "test_task"
|
||||
assert find["metadata"]["expected_output"] == "test_output"
|
||||
@@ -1,231 +0,0 @@
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory():
|
||||
"""Fixture to create a ShortTermMemory instance"""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||
|
||||
|
||||
def test_short_term_memory_search_events(short_term_memory):
|
||||
events = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "search", return_value=[]):
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
short_term_memory.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), (
|
||||
"Timeout waiting for search started event"
|
||||
)
|
||||
assert search_completed.wait(timeout=2), (
|
||||
"Timeout waiting for search completed event"
|
||||
)
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
|
||||
assert dict(events["MemoryQueryStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "short_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test value",
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
}
|
||||
|
||||
assert dict(events["MemoryQueryCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_query_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "short_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"query": "test value",
|
||||
"results": [],
|
||||
"limit": 3,
|
||||
"score_threshold": 0.35,
|
||||
"query_time_ms": ANY,
|
||||
}
|
||||
|
||||
|
||||
def test_short_term_memory_save_events(short_term_memory):
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
short_term_memory.save(
|
||||
value="test value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
|
||||
assert dict(events["MemorySaveStartedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_started",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "short_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "test value",
|
||||
"metadata": {"task": "test_task"},
|
||||
}
|
||||
|
||||
assert dict(events["MemorySaveCompletedEvent"][0]) == {
|
||||
"timestamp": ANY,
|
||||
"type": "memory_save_completed",
|
||||
"source_fingerprint": None,
|
||||
"source_type": "short_term_memory",
|
||||
"fingerprint_metadata": None,
|
||||
"task_id": None,
|
||||
"task_name": None,
|
||||
"from_task": None,
|
||||
"from_agent": None,
|
||||
"agent_role": None,
|
||||
"agent_id": None,
|
||||
"event_id": ANY,
|
||||
"parent_event_id": None,
|
||||
"previous_event_id": ANY,
|
||||
"triggered_by_event_id": None,
|
||||
"started_event_id": ANY,
|
||||
"emission_sequence": ANY,
|
||||
"value": "test value",
|
||||
"metadata": {"task": "test_task"},
|
||||
"save_time_ms": ANY,
|
||||
}
|
||||
|
||||
|
||||
def test_save_and_search(short_term_memory):
|
||||
memory = ShortTermMemoryItem(
|
||||
data="""test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value""",
|
||||
agent="test_agent",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
with patch.object(ShortTermMemory, "save") as mock_save:
|
||||
short_term_memory.save(
|
||||
value=memory.data,
|
||||
metadata=memory.metadata,
|
||||
agent=memory.agent,
|
||||
)
|
||||
|
||||
mock_save.assert_called_once_with(
|
||||
value=memory.data,
|
||||
metadata=memory.metadata,
|
||||
agent=memory.agent,
|
||||
)
|
||||
|
||||
expected_result = [
|
||||
{
|
||||
"content": memory.data,
|
||||
"metadata": {"agent": "test_agent"},
|
||||
"score": 0.95,
|
||||
}
|
||||
]
|
||||
with patch.object(ShortTermMemory, "search", return_value=expected_result):
|
||||
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
||||
assert find["content"] == memory.data, "Data value mismatch."
|
||||
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
|
||||
622
lib/crewai/tests/memory/test_unified_memory.py
Normal file
622
lib/crewai/tests/memory/test_unified_memory.py
Normal file
@@ -0,0 +1,622 @@
|
||||
"""Tests for unified memory: types, storage, Memory, MemoryScope, MemorySlice, Flow integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.memory.types import (
|
||||
MemoryConfig,
|
||||
MemoryMatch,
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
compute_composite_score,
|
||||
)
|
||||
|
||||
|
||||
# --- Types ---
|
||||
|
||||
|
||||
def test_memory_record_defaults() -> None:
|
||||
r = MemoryRecord(content="hello")
|
||||
assert r.content == "hello"
|
||||
assert r.scope == "/"
|
||||
assert r.categories == []
|
||||
assert r.importance == 0.5
|
||||
assert r.embedding is None
|
||||
assert r.id is not None
|
||||
assert isinstance(r.created_at, datetime)
|
||||
|
||||
|
||||
def test_memory_match() -> None:
|
||||
r = MemoryRecord(content="x", scope="/a")
|
||||
m = MemoryMatch(record=r, score=0.9, match_reasons=["semantic"])
|
||||
assert m.record.content == "x"
|
||||
assert m.score == 0.9
|
||||
assert m.match_reasons == ["semantic"]
|
||||
|
||||
|
||||
def test_scope_info() -> None:
|
||||
i = ScopeInfo(path="/", record_count=5, categories=["c1"], child_scopes=["/a"])
|
||||
assert i.path == "/"
|
||||
assert i.record_count == 5
|
||||
assert i.categories == ["c1"]
|
||||
assert i.child_scopes == ["/a"]
|
||||
|
||||
|
||||
def test_memory_config() -> None:
|
||||
c = MemoryConfig()
|
||||
assert c.recency_weight == 0.3
|
||||
assert c.semantic_weight == 0.5
|
||||
assert c.importance_weight == 0.2
|
||||
|
||||
|
||||
# --- LanceDB storage ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lancedb_path(tmp_path: Path) -> Path:
|
||||
return tmp_path / "mem"
|
||||
|
||||
|
||||
def test_lancedb_save_search(lancedb_path: Path) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
|
||||
r = MemoryRecord(
|
||||
content="test content",
|
||||
scope="/foo",
|
||||
categories=["cat1"],
|
||||
importance=0.8,
|
||||
embedding=[0.1, 0.2, 0.3, 0.4],
|
||||
)
|
||||
storage.save([r])
|
||||
results = storage.search(
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
scope_prefix="/foo",
|
||||
limit=5,
|
||||
)
|
||||
assert len(results) == 1
|
||||
rec, score = results[0]
|
||||
assert rec.content == "test content"
|
||||
assert rec.scope == "/foo"
|
||||
assert score >= 0.0
|
||||
|
||||
|
||||
def test_lancedb_delete_count(lancedb_path: Path) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
|
||||
r = MemoryRecord(content="x", scope="/", embedding=[0.0] * 4)
|
||||
storage.save([r])
|
||||
assert storage.count() == 1
|
||||
n = storage.delete(scope_prefix="/")
|
||||
assert n >= 1
|
||||
assert storage.count() == 0
|
||||
|
||||
|
||||
def test_lancedb_list_scopes_get_scope_info(lancedb_path: Path) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
|
||||
storage.save([
|
||||
MemoryRecord(content="a", scope="/", embedding=[0.0] * 4),
|
||||
MemoryRecord(content="b", scope="/team", embedding=[0.0] * 4),
|
||||
])
|
||||
scopes = storage.list_scopes("/")
|
||||
info = storage.get_scope_info("/")
|
||||
assert info.record_count >= 1
|
||||
assert info.path == "/"
|
||||
|
||||
|
||||
# --- Memory class (with mock embedder, no LLM for explicit remember) ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedder() -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.return_value = [[0.1] * 1536] # Chroma-style returns list of lists
|
||||
return m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_with_storage(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
import os
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
# Avoid default embedder that needs API
|
||||
pass
|
||||
|
||||
|
||||
def test_memory_remember_recall_shallow(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
m = Memory(
|
||||
storage=str(tmp_path / "db"),
|
||||
llm=MagicMock(),
|
||||
embedder=mock_embedder,
|
||||
)
|
||||
# Explicit scope/categories/importance so no LLM analysis
|
||||
r = m.remember(
|
||||
"We decided to use Python.",
|
||||
scope="/project",
|
||||
categories=["decision"],
|
||||
importance=0.7,
|
||||
)
|
||||
assert r.content == "We decided to use Python."
|
||||
assert r.scope == "/project"
|
||||
|
||||
matches = m.recall("Python decision", scope="/project", limit=5, depth="shallow")
|
||||
assert len(matches) >= 1
|
||||
assert "Python" in matches[0].record.content or "python" in matches[0].record.content.lower()
|
||||
|
||||
|
||||
def test_memory_forget(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
m = Memory(storage=str(tmp_path / "db2"), llm=MagicMock(), embedder=mock_embedder)
|
||||
m.remember("To forget", scope="/x", categories=[], importance=0.5, metadata={})
|
||||
assert m._storage.count("/x") >= 1
|
||||
n = m.forget(scope="/x")
|
||||
assert n >= 1
|
||||
assert m._storage.count("/x") == 0
|
||||
|
||||
|
||||
def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db3"), llm=MagicMock(), embedder=mock_embedder)
|
||||
sc = mem.scope("/agent/1")
|
||||
assert sc._root in ("/agent/1", "/agent/1/")
|
||||
sl = mem.slice(["/a", "/b"], read_only=True)
|
||||
assert sl._read_only is True
|
||||
assert "/a" in sl._scopes and "/b" in sl._scopes
|
||||
|
||||
|
||||
def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
m = Memory(storage=str(tmp_path / "db4"), llm=MagicMock(), embedder=mock_embedder)
|
||||
m.remember("Root", scope="/", categories=[], importance=0.5, metadata={})
|
||||
scopes = m.list_scopes("/")
|
||||
info = m.info("/")
|
||||
assert info.record_count >= 1
|
||||
tree = m.tree("/", max_depth=2)
|
||||
assert "/" in tree or "0 records" in tree or "1 records" in tree
|
||||
|
||||
|
||||
# --- MemoryScope ---
|
||||
|
||||
|
||||
def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.memory.memory_scope import MemoryScope
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder)
|
||||
scope = MemoryScope(mem, "/crew/1")
|
||||
scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={})
|
||||
results = scope.recall("note", limit=5, depth="shallow")
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
# --- MemorySlice recall (read-only) ---
|
||||
|
||||
|
||||
def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder)
|
||||
mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={})
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
matches = sl.recall("scope", limit=5, depth="shallow")
|
||||
assert isinstance(matches, list)
|
||||
|
||||
|
||||
def test_memory_slice_remember_raises_when_read_only(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder)
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
with pytest.raises(PermissionError):
|
||||
sl.remember("x", scope="/a")
|
||||
|
||||
|
||||
# --- Flow memory ---
|
||||
|
||||
|
||||
def test_flow_has_default_memory() -> None:
|
||||
"""Flow auto-creates a Memory instance when none is provided."""
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
class DefaultFlow(Flow):
|
||||
pass
|
||||
|
||||
f = DefaultFlow()
|
||||
assert f.memory is not None
|
||||
assert isinstance(f.memory, Memory)
|
||||
|
||||
|
||||
def test_flow_recall_remember_raise_when_memory_explicitly_none() -> None:
|
||||
"""Flow raises ValueError when memory is explicitly set to None."""
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
class NoMemoryFlow(Flow):
|
||||
memory = None
|
||||
|
||||
f = NoMemoryFlow()
|
||||
# Explicitly set to None after __init__ auto-creates
|
||||
f.memory = None
|
||||
with pytest.raises(ValueError, match="No memory configured"):
|
||||
f.recall("query")
|
||||
with pytest.raises(ValueError, match="No memory configured"):
|
||||
f.remember("content")
|
||||
|
||||
|
||||
def test_flow_recall_remember_with_memory(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "flow_db"), llm=MagicMock(), embedder=mock_embedder)
|
||||
|
||||
class FlowWithMemory(Flow):
|
||||
memory = mem
|
||||
|
||||
f = FlowWithMemory()
|
||||
f.remember("Flow remembered this", scope="/flow", categories=[], importance=0.6, metadata={})
|
||||
results = f.recall("remembered", limit=5, depth="shallow")
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
# --- extract_memories ---
|
||||
|
||||
|
||||
def test_memory_extract_memories_returns_list_from_llm(tmp_path: Path) -> None:
|
||||
"""Memory.extract_memories() delegates to LLM and returns list of strings."""
|
||||
from crewai.memory.analyze import ExtractedMemories
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.supports_function_calling.return_value = True
|
||||
mock_llm.call.return_value = ExtractedMemories(
|
||||
memories=["We use Python for the backend.", "API rate limit is 100/min."]
|
||||
)
|
||||
|
||||
mem = Memory(
|
||||
storage=str(tmp_path / "extract_db"),
|
||||
llm=mock_llm,
|
||||
embedder=MagicMock(return_value=[[0.1] * 1536]),
|
||||
)
|
||||
result = mem.extract_memories("Task: Build API. Result: We used Python and set rate limit 100/min.")
|
||||
assert result == ["We use Python for the backend.", "API rate limit is 100/min."]
|
||||
mock_llm.call.assert_called_once()
|
||||
call_kw = mock_llm.call.call_args[1]
|
||||
assert call_kw.get("response_model") == ExtractedMemories
|
||||
|
||||
|
||||
def test_memory_extract_memories_empty_content_returns_empty_list(tmp_path: Path) -> None:
|
||||
"""Memory.extract_memories() with empty/whitespace content returns [] without calling LLM."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mem = Memory(storage=str(tmp_path / "empty_db"), llm=mock_llm, embedder=MagicMock())
|
||||
assert mem.extract_memories("") == []
|
||||
assert mem.extract_memories(" \n ") == []
|
||||
mock_llm.call.assert_not_called()
|
||||
|
||||
|
||||
def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
|
||||
"""_save_to_memory calls memory.extract_memories(raw) then memory.remember(m) for each."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import AgentFinish
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.memory = mock_memory
|
||||
mock_agent._logger = MagicMock()
|
||||
mock_agent.role = "Researcher"
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.description = "Do research"
|
||||
mock_task.expected_output = "A report"
|
||||
|
||||
class MinimalExecutor(CrewAgentExecutorMixin):
|
||||
crew = None
|
||||
agent = mock_agent
|
||||
task = mock_task
|
||||
iterations = 0
|
||||
max_iter = 1
|
||||
messages = []
|
||||
_i18n = MagicMock()
|
||||
_printer = Printer()
|
||||
|
||||
executor = MinimalExecutor()
|
||||
executor._save_to_memory(
|
||||
AgentFinish(thought="", output="We found X and Y.", text="We found X and Y.")
|
||||
)
|
||||
|
||||
raw_expected = "Task: Do research\nAgent: Researcher\nExpected result: A report\nResult: We found X and Y."
|
||||
mock_memory.extract_memories.assert_called_once_with(raw_expected)
|
||||
assert mock_memory.remember.call_count == 2
|
||||
calls = [mock_memory.remember.call_args_list[i][0][0] for i in range(2)]
|
||||
assert calls == ["Fact A.", "Fact B."]
|
||||
|
||||
|
||||
def test_executor_save_to_memory_skips_delegation_output() -> None:
|
||||
"""_save_to_memory does nothing when output contains delegate action."""
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.memory = mock_memory
|
||||
mock_agent._logger = MagicMock()
|
||||
mock_task = MagicMock(description="Task", expected_output="Out")
|
||||
|
||||
class MinimalExecutor(CrewAgentExecutorMixin):
|
||||
crew = None
|
||||
agent = mock_agent
|
||||
task = mock_task
|
||||
iterations = 0
|
||||
max_iter = 1
|
||||
messages = []
|
||||
_i18n = MagicMock()
|
||||
_printer = Printer()
|
||||
|
||||
delegate_text = f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
full_text = delegate_text + " rest"
|
||||
executor = MinimalExecutor()
|
||||
executor._save_to_memory(
|
||||
AgentFinish(thought="", output=full_text, text=full_text)
|
||||
)
|
||||
|
||||
mock_memory.extract_memories.assert_not_called()
|
||||
mock_memory.remember.assert_not_called()
|
||||
|
||||
|
||||
def test_memory_scope_extract_memories_delegates() -> None:
|
||||
"""MemoryScope.extract_memories delegates to underlying Memory."""
|
||||
from crewai.memory.memory_scope import MemoryScope
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Scoped fact."]
|
||||
scope = MemoryScope(mock_memory, "/agent/1")
|
||||
result = scope.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Scoped fact."]
|
||||
|
||||
|
||||
def test_memory_slice_extract_memories_delegates() -> None:
|
||||
"""MemorySlice.extract_memories delegates to underlying Memory."""
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Sliced fact."]
|
||||
sl = MemorySlice(mock_memory, ["/a", "/b"], read_only=True)
|
||||
result = sl.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Sliced fact."]
|
||||
|
||||
|
||||
def test_flow_extract_memories_raises_when_memory_explicitly_none() -> None:
|
||||
"""Flow.extract_memories raises ValueError when memory is explicitly set to None."""
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
f = Flow()
|
||||
f.memory = None
|
||||
with pytest.raises(ValueError, match="No memory configured"):
|
||||
f.extract_memories("some content")
|
||||
|
||||
|
||||
def test_flow_extract_memories_delegates_when_memory_present() -> None:
|
||||
"""Flow.extract_memories delegates to flow memory and returns list."""
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Flow fact 1.", "Flow fact 2."]
|
||||
|
||||
class FlowWithMemory(Flow):
|
||||
memory = mock_memory
|
||||
|
||||
f = FlowWithMemory()
|
||||
result = f.extract_memories("content here")
|
||||
mock_memory.extract_memories.assert_called_once_with("content here")
|
||||
assert result == ["Flow fact 1.", "Flow fact 2."]
|
||||
|
||||
|
||||
# --- Composite scoring ---
|
||||
|
||||
|
||||
def test_composite_score_brand_new_memory() -> None:
|
||||
"""Brand-new memory has decay ~ 1.0; composite = 0.5*0.8 + 0.3*1.0 + 0.2*0.7 = 0.84."""
|
||||
config = MemoryConfig()
|
||||
record = MemoryRecord(
|
||||
content="test",
|
||||
scope="/",
|
||||
importance=0.7,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
score, reasons = compute_composite_score(record, 0.8, config)
|
||||
assert 0.82 <= score <= 0.86
|
||||
assert "semantic" in reasons
|
||||
assert "recency" in reasons
|
||||
assert "importance" in reasons
|
||||
|
||||
|
||||
def test_composite_score_old_memory_decayed() -> None:
|
||||
"""Memory 60 days old (2 half-lives) has decay = 0.25; composite ~ 0.575."""
|
||||
config = MemoryConfig(recency_half_life_days=30)
|
||||
old_date = datetime.utcnow() - timedelta(days=60)
|
||||
record = MemoryRecord(
|
||||
content="old",
|
||||
scope="/",
|
||||
importance=0.5,
|
||||
created_at=old_date,
|
||||
)
|
||||
score, reasons = compute_composite_score(record, 0.8, config)
|
||||
assert 0.55 <= score <= 0.60
|
||||
assert "semantic" in reasons
|
||||
assert "recency" not in reasons # decay 0.25 is not > 0.5
|
||||
|
||||
|
||||
def test_composite_score_reranks_results(
|
||||
tmp_path: Path, mock_embedder: MagicMock
|
||||
) -> None:
|
||||
"""Same semantic score: high-importance recent memory ranks first."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
# Use same dim as default LanceDB (1536) so storage does not overwrite embedding
|
||||
emb = [0.1] * 1536
|
||||
mem = Memory(
|
||||
storage=str(tmp_path / "rerank_db"),
|
||||
llm=MagicMock(),
|
||||
embedder=MagicMock(return_value=[emb]),
|
||||
)
|
||||
# Same embedding so same semantic score from storage
|
||||
mem.remember(
|
||||
"Important decision",
|
||||
scope="/",
|
||||
categories=[],
|
||||
importance=1.0,
|
||||
metadata={},
|
||||
)
|
||||
old = datetime.utcnow() - timedelta(days=90)
|
||||
record_low = MemoryRecord(
|
||||
content="Old trivial note",
|
||||
scope="/",
|
||||
importance=0.1,
|
||||
created_at=old,
|
||||
embedding=emb,
|
||||
)
|
||||
mem._storage.save([record_low])
|
||||
|
||||
matches = mem.recall("decision", scope="/", limit=5, depth="shallow")
|
||||
assert len(matches) >= 2
|
||||
# Top result should be the high-importance recent one (stored via remember)
|
||||
assert "Important" in matches[0].record.content or "important" in matches[0].record.content.lower()
|
||||
|
||||
|
||||
def test_composite_score_match_reasons_populated() -> None:
|
||||
"""match_reasons includes recency for fresh, importance for high-importance; omits for old/low."""
|
||||
config = MemoryConfig()
|
||||
fresh_high = MemoryRecord(
|
||||
content="x",
|
||||
importance=0.9,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
score1, reasons1 = compute_composite_score(fresh_high, 0.5, config)
|
||||
assert "semantic" in reasons1
|
||||
assert "recency" in reasons1
|
||||
assert "importance" in reasons1
|
||||
|
||||
old_low = MemoryRecord(
|
||||
content="y",
|
||||
importance=0.1,
|
||||
created_at=datetime.utcnow() - timedelta(days=60),
|
||||
)
|
||||
score2, reasons2 = compute_composite_score(old_low, 0.5, config)
|
||||
assert "semantic" in reasons2
|
||||
assert "recency" not in reasons2
|
||||
assert "importance" not in reasons2
|
||||
|
||||
|
||||
def test_composite_score_custom_config() -> None:
|
||||
"""Zero recency/importance weights => composite equals semantic score."""
|
||||
config = MemoryConfig(
|
||||
recency_weight=0.0,
|
||||
semantic_weight=1.0,
|
||||
importance_weight=0.0,
|
||||
)
|
||||
record = MemoryRecord(
|
||||
content="any",
|
||||
importance=0.9,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
score, reasons = compute_composite_score(record, 0.73, config)
|
||||
assert score == pytest.approx(0.73, rel=1e-5)
|
||||
assert "semantic" in reasons
|
||||
|
||||
|
||||
# --- LLM fallback ---
|
||||
|
||||
|
||||
def test_analyze_for_save_llm_failure_returns_defaults() -> None:
|
||||
"""When LLM raises, analyze_for_save returns safe defaults."""
|
||||
from crewai.memory.analyze import MemoryAnalysis, analyze_for_save
|
||||
|
||||
llm = MagicMock()
|
||||
llm.call.side_effect = RuntimeError("API rate limit")
|
||||
result = analyze_for_save(
|
||||
"some content",
|
||||
existing_scopes=["/", "/project"],
|
||||
existing_categories=["cat1"],
|
||||
llm=llm,
|
||||
)
|
||||
assert isinstance(result, MemoryAnalysis)
|
||||
assert result.suggested_scope == "/"
|
||||
assert result.categories == []
|
||||
assert result.importance == 0.5
|
||||
assert result.extracted_metadata.entities == []
|
||||
assert result.extracted_metadata.dates == []
|
||||
assert result.extracted_metadata.topics == []
|
||||
|
||||
|
||||
def test_extract_memories_llm_failure_returns_raw() -> None:
|
||||
"""When LLM raises, extract_memories_from_content returns [content]."""
|
||||
from crewai.memory.analyze import extract_memories_from_content
|
||||
|
||||
llm = MagicMock()
|
||||
llm.call.side_effect = RuntimeError("Network error")
|
||||
content = "Task result: We chose PostgreSQL."
|
||||
result = extract_memories_from_content(content, llm)
|
||||
assert result == [content]
|
||||
|
||||
|
||||
def test_analyze_query_llm_failure_returns_defaults() -> None:
|
||||
"""When LLM raises, analyze_query returns safe defaults with available scopes."""
|
||||
from crewai.memory.analyze import QueryAnalysis, analyze_query
|
||||
|
||||
llm = MagicMock()
|
||||
llm.call.side_effect = RuntimeError("Timeout")
|
||||
result = analyze_query(
|
||||
"what did we decide?",
|
||||
available_scopes=["/", "/project", "/team", "/company", "/other", "/extra"],
|
||||
scope_info=None,
|
||||
llm=llm,
|
||||
)
|
||||
assert isinstance(result, QueryAnalysis)
|
||||
assert result.keywords == []
|
||||
assert result.time_hints == []
|
||||
assert result.complexity == "simple"
|
||||
assert result.suggested_scopes == ["/", "/project", "/team", "/company", "/other"]
|
||||
|
||||
|
||||
def test_remember_survives_llm_failure(
|
||||
tmp_path: Path, mock_embedder: MagicMock
|
||||
) -> None:
|
||||
"""When analyze_for_save fails (LLM raises), remember() still saves with defaults."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
llm = MagicMock()
|
||||
llm.call.side_effect = RuntimeError("LLM unavailable")
|
||||
mem = Memory(
|
||||
storage=str(tmp_path / "fallback_db"),
|
||||
llm=llm,
|
||||
embedder=mock_embedder,
|
||||
)
|
||||
record = mem.remember("We decided to use PostgreSQL.")
|
||||
assert record.content == "We decided to use PostgreSQL."
|
||||
assert record.scope == "/"
|
||||
assert record.categories == []
|
||||
assert record.importance == 0.5
|
||||
assert record.id is not None
|
||||
assert mem._storage.count() == 1
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
from crewai.memory.storage.rag_storage import RAGStorage # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
@@ -67,31 +66,6 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_rag_storage_client_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles RAG client failures in memory operations."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.search.side_effect = RuntimeError("ChromaDB server error")
|
||||
|
||||
storage = RAGStorage("short_term", crew=None)
|
||||
|
||||
results = storage.search("test query")
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_rag_storage_save_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles save operation failures."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.add_documents.side_effect = Exception("Failed to add documents")
|
||||
|
||||
storage = RAGStorage("long_term", crew=None)
|
||||
|
||||
storage.save("test memory", {"key": "value"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_reset_readonly_database(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage reset handles readonly database errors."""
|
||||
@@ -120,21 +94,6 @@ def test_knowledge_storage_reset_collection_does_not_exist(
|
||||
storage.reset()
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_storage_reset_failure_propagation(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage reset propagates unexpected errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.delete_collection.side_effect = Exception("Unexpected database error")
|
||||
|
||||
storage = RAGStorage("entities", crew=None)
|
||||
|
||||
with pytest.raises(
|
||||
Exception, match="An error occurred while resetting the entities memory"
|
||||
):
|
||||
storage.reset()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_malformed_search_results(mock_get_client: MagicMock) -> None:
|
||||
"""Test KnowledgeStorage handles malformed search results."""
|
||||
@@ -181,20 +140,6 @@ def test_knowledge_storage_network_interruption(mock_get_client: MagicMock) -> N
|
||||
assert second_attempt[0]["content"] == "recovered result"
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.get_rag_client")
|
||||
def test_memory_storage_collection_creation_failure(mock_get_client: MagicMock) -> None:
|
||||
"""Test RAGStorage handles collection creation failures."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.get_or_create_collection.side_effect = Exception(
|
||||
"Failed to create collection"
|
||||
)
|
||||
|
||||
storage = RAGStorage("user_memory", crew=None)
|
||||
|
||||
storage.save("test data", {"metadata": "test"})
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_knowledge_storage_embedding_dimension_mismatch_detailed(
|
||||
mock_get_client: MagicMock,
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Tests for RAGStorage custom path functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path when provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/memory/path"
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_default_path_when_none(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses default path when no custom path is provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=None,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
assert storage.path is None
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path_with_batch_size(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path with batch_size in config."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/batch/path"
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"model": "text-embedding-3-small", "batch_size": 100},
|
||||
}
|
||||
|
||||
RAGStorage(
|
||||
type="long_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
assert config_arg.batch_size == 100
|
||||
@@ -1,504 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
from mem0 import Memory, MemoryClient
|
||||
|
||||
|
||||
# Define the class (if not already defined)
|
||||
class MockCrew:
|
||||
def __init__(self):
|
||||
self.agents = [MagicMock(role="Test Agent")]
|
||||
|
||||
|
||||
# Test data constants
|
||||
SYSTEM_CONTENT = (
|
||||
"You are Friendly chatbot assistant. You are a kind and "
|
||||
"knowledgeable chatbot assistant. You excel at understanding user needs, "
|
||||
"providing helpful responses, and maintaining engaging conversations. "
|
||||
"You remember previous interactions to provide a personalized experience.\n"
|
||||
"Your personal goal is: Engage in useful and interesting conversations "
|
||||
"with users while remembering context.\n"
|
||||
"To give my best complete final answer to the task respond using the exact "
|
||||
"following format:\n\n"
|
||||
"Thought: I now can give a great answer\n"
|
||||
"Final Answer: Your final answer must be the great and the most complete "
|
||||
"as possible, it must be outcome described.\n\n"
|
||||
"I MUST use these formats, my job depends on it!"
|
||||
)
|
||||
|
||||
USER_CONTENT = (
|
||||
"\nCurrent Task: Respond to user conversation. User message: "
|
||||
"What do you know about me?\n\n"
|
||||
"This is the expected criteria for your final answer: Contextually "
|
||||
"appropriate, helpful, and friendly response.\n"
|
||||
"you MUST return the actual complete content as the final answer, "
|
||||
"not a summary.\n\n"
|
||||
"# Useful context: \nExternal memories:\n"
|
||||
"- User is from India\n"
|
||||
"- User is interested in the solar system\n"
|
||||
"- User name is Vidit Ostwal\n"
|
||||
"- User is interested in French cuisine\n\n"
|
||||
"Begin! This is VERY important to you, use the tools available and give "
|
||||
"your best Final Answer, your job depends on it!\n\n"
|
||||
"Thought:"
|
||||
)
|
||||
|
||||
ASSISTANT_CONTENT = (
|
||||
"I now can give a great answer \n"
|
||||
"Final Answer: Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
TEST_DESCRIPTION = (
|
||||
"Respond to user conversation. User message: What do you know about me?"
|
||||
)
|
||||
|
||||
# Extracted content (after processing by _get_user_message and _get_assistant_message)
|
||||
EXTRACTED_USER_CONTENT = "What do you know about me?"
|
||||
EXTRACTED_ASSISTANT_CONTENT = (
|
||||
"Hi Vidit! From our previous conversations, I know you're "
|
||||
"from India and have a great interest in the solar system. It's fascinating "
|
||||
"to explore the wonders of space, isn't it? Also, I remember you have a "
|
||||
"passion for French cuisine, which has so many delightful dishes to explore. "
|
||||
"If there's anything specific you'd like to discuss or learn about—whether "
|
||||
"it's about the solar system or some great French recipes—feel free to let "
|
||||
"me know! I'm here to help."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
"""Fixture to create a mock Memory instance"""
|
||||
return MagicMock(spec=Memory)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# Patch the Memory class to return our mock
|
||||
with patch(
|
||||
"mem0.Memory.from_config", return_value=mock_mem0_memory
|
||||
) as mock_from_config:
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mock_vector_store",
|
||||
"config": {"host": "localhost", "port": 6333},
|
||||
},
|
||||
"llm": {
|
||||
"provider": "mock_llm",
|
||||
"config": {"api_key": "mock-api-key", "model": "mock-model"},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "mock_embedder",
|
||||
"config": {"api_key": "mock-api-key", "model": "mock-model"},
|
||||
},
|
||||
"graph_store": {
|
||||
"provider": "mock_graph_store",
|
||||
"config": {
|
||||
"url": "mock-url",
|
||||
"username": "mock-user",
|
||||
"password": "mock-password",
|
||||
},
|
||||
},
|
||||
"history_db_path": "/mock/path",
|
||||
"version": "test-version",
|
||||
"custom_fact_extraction_prompt": "mock prompt 1",
|
||||
"custom_update_memory_prompt": "mock prompt 2",
|
||||
}
|
||||
|
||||
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"local_mem0_config": config,
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
return mem0_storage, mock_from_config, config
|
||||
|
||||
|
||||
def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_memory):
|
||||
"""Test that Mem0Storage initializes correctly with the mocked config"""
|
||||
mem0_storage, mock_from_config, config = mem0_storage_with_mocked_config
|
||||
assert mem0_storage.memory_type == "short_term"
|
||||
assert mem0_storage.memory is mock_mem0_memory
|
||||
mock_from_config.assert_called_once_with(config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory_client():
|
||||
"""Fixture to create a mock MemoryClient instance"""
|
||||
return MagicMock(spec=MemoryClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_client):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch the MemoryClient before it's instantiated
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
embedder_config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"run_id": "my_run_id",
|
||||
"includes": "include1",
|
||||
"excludes": "exclude1",
|
||||
"infer": True,
|
||||
}
|
||||
|
||||
return Mem0Storage(type="short_term", crew=crew, config=embedder_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client_using_explictly_config(
|
||||
mock_mem0_memory_client, mock_mem0_memory
|
||||
):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch both MemoryClient and Memory to prevent actual initialization
|
||||
with (
|
||||
patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client),
|
||||
patch.object(Memory, "__new__", return_value=mock_mem0_memory),
|
||||
):
|
||||
crew = MockCrew()
|
||||
new_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
|
||||
|
||||
return Mem0Storage(type="short_term", crew=crew, config=new_config)
|
||||
|
||||
|
||||
def test_mem0_storage_with_memory_client_initialization(
|
||||
mem0_storage_with_memory_client_using_config_from_crew, mock_mem0_memory_client
|
||||
):
|
||||
"""Test Mem0Storage initialization with MemoryClient"""
|
||||
assert (
|
||||
mem0_storage_with_memory_client_using_config_from_crew.memory_type
|
||||
== "short_term"
|
||||
)
|
||||
assert (
|
||||
mem0_storage_with_memory_client_using_config_from_crew.memory
|
||||
is mock_mem0_memory_client
|
||||
)
|
||||
|
||||
|
||||
def test_mem0_storage_with_explict_config(
|
||||
mem0_storage_with_memory_client_using_explictly_config,
|
||||
):
|
||||
expected_config = {"provider": "mem0", "config": {"api_key": "new-api-key"}}
|
||||
assert (
|
||||
mem0_storage_with_memory_client_using_explictly_config.config == expected_config
|
||||
)
|
||||
|
||||
|
||||
def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_client):
|
||||
mock_mem0_memory_client.update_project = MagicMock()
|
||||
|
||||
new_categories = [
|
||||
{
|
||||
"lifestyle_management_concerns": (
|
||||
"Tracks daily routines, habits, hobbies and interests "
|
||||
"including cooking, time management and work-life balance"
|
||||
)
|
||||
},
|
||||
]
|
||||
|
||||
crew = MockCrew()
|
||||
|
||||
config = {
|
||||
"user_id": "test_user",
|
||||
"api_key": "ABCDEFGH",
|
||||
"org_id": "my_org_id",
|
||||
"project_id": "my_project_id",
|
||||
"custom_categories": new_categories,
|
||||
}
|
||||
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
_ = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
|
||||
mock_mem0_memory_client.update_project.assert_called_once_with(
|
||||
custom_categories=new_categories
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mem0_storage.crew.agents = [
|
||||
MagicMock(role="Test Agent"),
|
||||
MagicMock(role="Test Agent 2"),
|
||||
MagicMock(role="Test Agent 3"),
|
||||
]
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
run_id="my_run_id",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent_Test_Agent_2_Test_Agent_3",
|
||||
)
|
||||
|
||||
|
||||
def test_save_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test save method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mem0_storage.memory.add = MagicMock()
|
||||
|
||||
# Test short_term memory type (already set in fixture)
|
||||
test_value = "This is a test memory"
|
||||
test_metadata = {
|
||||
"description": TEST_DESCRIPTION,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_CONTENT},
|
||||
{"role": "user", "content": USER_CONTENT},
|
||||
{"role": "assistant", "content": ASSISTANT_CONTENT},
|
||||
],
|
||||
"agent": "Friendly chatbot assistant",
|
||||
}
|
||||
|
||||
mem0_storage.save(test_value, test_metadata)
|
||||
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[
|
||||
{"role": "user", "content": EXTRACTED_USER_CONTENT},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": EXTRACTED_ASSISTANT_CONTENT,
|
||||
},
|
||||
],
|
||||
infer=True,
|
||||
metadata={
|
||||
"type": "short_term",
|
||||
"description": TEST_DESCRIPTION,
|
||||
"agent": "Friendly chatbot assistant",
|
||||
},
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
includes="include1",
|
||||
excludes="exclude1",
|
||||
output_format="v1.1",
|
||||
user_id="test_user",
|
||||
agent_id="Test_Agent",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage, _, _ = mem0_storage_with_mocked_config
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="test_user",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_memory_client(
|
||||
mem0_storage_with_memory_client_using_config_from_crew,
|
||||
):
|
||||
"""Test search method for different memory types"""
|
||||
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
metadata={"type": "short_term"},
|
||||
user_id="test_user",
|
||||
version="v2",
|
||||
run_id="my_run_id",
|
||||
output_format="v1.1",
|
||||
filters={"AND": [{"run_id": "my_run_id"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
|
||||
"""Test that Mem0Storage sets infer=True by default for short_term memory."""
|
||||
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew()
|
||||
|
||||
config = {"user_id": "test_user", "api_key": "ABCDEFGH"}
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
assert mem0_storage.infer is True
|
||||
|
||||
|
||||
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
}
|
||||
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
mem0_storage.save("test memory", {"key": "value"})
|
||||
mem0_storage.memory.add.assert_called_once_with(
|
||||
[{"role": "assistant", "content": "test memory"}],
|
||||
infer=True,
|
||||
metadata={"type": "external", "key": "value"},
|
||||
agent_id="agent-123",
|
||||
)
|
||||
|
||||
|
||||
def test_search_method_with_agent_entity():
|
||||
config = {
|
||||
"agent_id": "agent-123",
|
||||
}
|
||||
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config=config)
|
||||
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
filters={"AND": [{"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
|
||||
|
||||
def test_search_method_with_agent_id_and_user_id():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {
|
||||
"results": [
|
||||
{"score": 0.9, "memory": "Result 1"},
|
||||
{"score": 0.4, "memory": "Result 2"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(
|
||||
type="external", config={"agent_id": "agent-123", "user_id": "user-123"}
|
||||
)
|
||||
|
||||
mem0_storage.memory.search = MagicMock(return_value=mock_results)
|
||||
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
|
||||
|
||||
mem0_storage.memory.search.assert_called_once_with(
|
||||
query="test query",
|
||||
limit=5,
|
||||
user_id="user-123",
|
||||
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
|
||||
threshold=0.5,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
@@ -36,10 +36,7 @@ from crewai.flow import Flow, start
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
from crewai.task import Task
|
||||
@@ -2425,7 +2422,8 @@ def test_multiple_conditional_tasks(researcher, writer):
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_using_contextual_memory():
|
||||
def test_using_memory():
|
||||
"""With memory=True, crew has _memory and kickoff runs successfully."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2445,11 +2443,8 @@ def test_using_contextual_memory():
|
||||
memory=True,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ContextualMemory, "build_context_for_task", return_value=""
|
||||
) as contextual_mem:
|
||||
crew.kickoff()
|
||||
contextual_mem.assert_called_once()
|
||||
crew.kickoff()
|
||||
assert crew._memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@@ -2529,28 +2524,32 @@ def test_memory_events_are_emitted():
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: (
|
||||
len(events["MemorySaveStartedEvent"]) >= 3
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 3
|
||||
and len(events["MemoryQueryStartedEvent"]) >= 3
|
||||
and len(events["MemoryQueryCompletedEvent"]) >= 3
|
||||
len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and (
|
||||
len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
or len(events["MemorySaveFailedEvent"]) >= 1
|
||||
)
|
||||
and len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
and len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
and len(events["MemoryRetrievalCompletedEvent"]) >= 1
|
||||
),
|
||||
timeout=10,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
assert success, f"Timeout waiting for memory events. Got: {dict(events)}"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 3
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 3
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 3
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 3
|
||||
assert len(events["MemoryQueryFailedEvent"]) == 0
|
||||
assert len(events["MemoryRetrievalStartedEvent"]) == 1
|
||||
assert len(events["MemoryRetrievalCompletedEvent"]) == 1
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1, (
|
||||
f"Expected at least one MemorySaveCompletedEvent; got MemorySaveFailedEvent: {events.get('MemorySaveFailedEvent')}"
|
||||
)
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert len(events["MemoryRetrievalStartedEvent"]) >= 1
|
||||
assert len(events["MemoryRetrievalCompletedEvent"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_using_contextual_memory_with_long_term_memory():
|
||||
def test_using_memory_with_remember():
|
||||
"""With memory=True, crew uses unified memory and kickoff runs successfully."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2567,19 +2566,16 @@ def test_using_contextual_memory_with_long_term_memory():
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
long_term_memory=LongTermMemory(),
|
||||
memory=True,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ContextualMemory, "build_context_for_task", return_value=""
|
||||
) as contextual_mem:
|
||||
crew.kickoff()
|
||||
contextual_mem.assert_called_once()
|
||||
assert crew.memory is False
|
||||
crew.kickoff()
|
||||
assert crew._memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_warning_long_term_memory_without_entity_memory():
|
||||
def test_memory_enabled_creates_unified_memory():
|
||||
"""With unified memory, memory=True creates _memory and kickoff runs."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2597,55 +2593,16 @@ def test_warning_long_term_memory_without_entity_memory():
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
long_term_memory=LongTermMemory(),
|
||||
memory=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("crewai.utilities.printer.Printer.print") as mock_print,
|
||||
patch(
|
||||
"crewai.memory.long_term.long_term_memory.LongTermMemory.save"
|
||||
) as save_memory,
|
||||
):
|
||||
crew.kickoff()
|
||||
mock_print.assert_called_with(
|
||||
content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.",
|
||||
color="bold_yellow",
|
||||
)
|
||||
save_memory.assert_not_called()
|
||||
crew.kickoff()
|
||||
assert crew._memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_long_term_memory_with_memory_flag():
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
backstory="You're an expert in research and you love to learn new things.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research a topic to teach a kid aged 6 about math.",
|
||||
expected_output="A topic, explanation, angle, and examples.",
|
||||
agent=math_researcher,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("crewai.utilities.printer.Printer.print") as mock_print,
|
||||
patch("crewai.memory.long_term.long_term_memory.LongTermMemory.save") as save_memory,
|
||||
):
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
memory=True,
|
||||
long_term_memory=LongTermMemory(),
|
||||
)
|
||||
crew.kickoff()
|
||||
mock_print.assert_not_called()
|
||||
save_memory.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_using_contextual_memory_with_short_term_memory():
|
||||
def test_memory_remember_called_after_task():
|
||||
"""With memory=True, extract_memories is called with raw content and remember is called per extracted item."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2662,19 +2619,58 @@ def test_using_contextual_memory_with_short_term_memory():
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
short_term_memory=ShortTermMemory(),
|
||||
memory=True,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ContextualMemory, "build_context_for_task", return_value=""
|
||||
) as contextual_mem:
|
||||
crew._memory, "extract_memories", wraps=crew._memory.extract_memories
|
||||
) as extract_mock, patch.object(
|
||||
crew._memory, "remember", wraps=crew._memory.remember
|
||||
) as remember_mock:
|
||||
crew.kickoff()
|
||||
contextual_mem.assert_called_once()
|
||||
assert crew.memory is False
|
||||
|
||||
# extract_memories should be called with the raw content blob
|
||||
extract_mock.assert_called()
|
||||
raw = extract_mock.call_args.args[0]
|
||||
assert "Task:" in raw
|
||||
assert "Agent:" in raw or "Researcher" in raw
|
||||
|
||||
# remember should be called once per extracted memory (may be 0 if LLM returned none)
|
||||
if remember_mock.called:
|
||||
for call in remember_mock.call_args_list:
|
||||
content = call.args[0] if call.args else call.kwargs.get("content", "")
|
||||
assert isinstance(content, str) and len(content) > 0
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_disabled_memory_using_contextual_memory():
|
||||
def test_using_memory_recall_and_save():
|
||||
"""With memory=True, crew uses unified memory for recall and save."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
backstory="You're an expert in research and you love to learn new things.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research a topic to teach a kid aged 6 about math.",
|
||||
expected_output="A topic, explanation, angle, and examples.",
|
||||
agent=math_researcher,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
memory=True,
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
assert crew._memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_disabled_memory():
|
||||
"""With memory=False, crew has no _memory and kickoff runs without memory."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -2694,11 +2690,8 @@ def test_disabled_memory_using_contextual_memory():
|
||||
memory=False,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ContextualMemory, "build_context_for_task", return_value=""
|
||||
) as contextual_mem:
|
||||
crew.kickoff()
|
||||
contextual_mem.assert_not_called()
|
||||
crew.kickoff()
|
||||
assert getattr(crew, "_memory", None) is None
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@@ -4446,68 +4439,21 @@ def test_crew_kickoff_for_each_works_with_manager_agent_copy():
|
||||
|
||||
|
||||
def test_crew_copy_with_memory():
|
||||
"""Test that copying a crew with memory enabled does not raise validation errors and copies memory correctly."""
|
||||
"""Test that copying a crew with memory enabled does not raise and shares the same memory instance."""
|
||||
agent = Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory")
|
||||
task = Task(description="Test Task", expected_output="Test Output", agent=agent)
|
||||
crew = Crew(agents=[agent], tasks=[task], memory=True)
|
||||
|
||||
original_short_term_id = (
|
||||
id(crew._short_term_memory) if crew._short_term_memory else None
|
||||
)
|
||||
original_long_term_id = (
|
||||
id(crew._long_term_memory) if crew._long_term_memory else None
|
||||
)
|
||||
original_entity_id = id(crew._entity_memory) if crew._entity_memory else None
|
||||
original_external_id = id(crew._external_memory) if crew._external_memory else None
|
||||
assert crew._memory is not None, "Crew with memory=True should have _memory"
|
||||
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
||||
"Copied crew should have _short_term_memory"
|
||||
assert hasattr(crew_copy, "_memory"), "Copied crew should have _memory"
|
||||
assert crew_copy._memory is not None, "Copied _memory should not be None"
|
||||
assert crew_copy._memory is crew._memory, (
|
||||
"Copy passes memory=self._memory so clone shares the same memory"
|
||||
)
|
||||
assert crew_copy._short_term_memory is not None, (
|
||||
"Copied _short_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
||||
"Copied _short_term_memory should be a new object"
|
||||
)
|
||||
|
||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
||||
"Copied crew should have _long_term_memory"
|
||||
)
|
||||
assert crew_copy._long_term_memory is not None, (
|
||||
"Copied _long_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
||||
"Copied _long_term_memory should be a new object"
|
||||
)
|
||||
|
||||
assert hasattr(crew_copy, "_entity_memory"), (
|
||||
"Copied crew should have _entity_memory"
|
||||
)
|
||||
assert crew_copy._entity_memory is not None, (
|
||||
"Copied _entity_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
||||
"Copied _entity_memory should be a new object"
|
||||
)
|
||||
|
||||
if original_external_id:
|
||||
assert hasattr(crew_copy, "_external_memory"), (
|
||||
"Copied crew should have _external_memory"
|
||||
)
|
||||
assert crew_copy._external_memory is not None, (
|
||||
"Copied _external_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._external_memory) != original_external_id, (
|
||||
"Copied _external_memory should be a new object"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not hasattr(crew_copy, "_external_memory")
|
||||
or crew_copy._external_memory is None
|
||||
), "Copied _external_memory should be None if not originally present"
|
||||
|
||||
except pydantic_core.ValidationError as e:
|
||||
if "Input should be an instance of" in str(e) and ("Memory" in str(e)):
|
||||
@@ -4515,7 +4461,7 @@ def test_crew_copy_with_memory():
|
||||
f"Copying with memory raised Pydantic ValidationError, likely due to incorrect memory copy: {e}"
|
||||
)
|
||||
else:
|
||||
raise e # Re-raise other validation errors
|
||||
raise e
|
||||
except Exception as e:
|
||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||
|
||||
@@ -4807,9 +4753,8 @@ def test_default_crew_name(researcher, writer):
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
external_memory = ExternalMemory(storage=MagicMock())
|
||||
|
||||
def test_memory_remember_receives_task_content():
|
||||
"""With memory=True, extract_memories receives raw content with task, agent, expected output, and result."""
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
@@ -4826,33 +4771,21 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
external_memory=external_memory,
|
||||
memory=True,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ExternalMemory, "save", return_value=None
|
||||
) as external_memory_save:
|
||||
crew._memory, "extract_memories", wraps=crew._memory.extract_memories
|
||||
) as extract_mock:
|
||||
crew.kickoff()
|
||||
|
||||
external_memory_save.assert_called_once()
|
||||
extract_mock.assert_called()
|
||||
raw = extract_mock.call_args.args[0]
|
||||
|
||||
call_args = external_memory_save.call_args
|
||||
|
||||
assert "value" in call_args.kwargs or len(call_args.args) > 0
|
||||
assert "metadata" in call_args.kwargs or len(call_args.args) > 1
|
||||
|
||||
if "metadata" in call_args.kwargs:
|
||||
metadata = call_args.kwargs["metadata"]
|
||||
else:
|
||||
metadata = call_args.args[1]
|
||||
|
||||
assert "description" in metadata
|
||||
assert "messages" in metadata
|
||||
assert isinstance(metadata["messages"], list)
|
||||
assert len(metadata["messages"]) >= 2
|
||||
|
||||
messages = metadata["messages"]
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
# The raw content passed to extract_memories should contain the task context
|
||||
assert "Task:" in raw
|
||||
assert "Research" in raw or "topic" in raw
|
||||
assert "Agent:" in raw
|
||||
assert "Researcher" in raw
|
||||
assert "Expected result:" in raw
|
||||
assert "Result:" in raw
|
||||
|
||||
@@ -15,7 +15,7 @@ dependencies = [
|
||||
"openai~=1.83.0",
|
||||
"python-dotenv~=1.1.1",
|
||||
"pygithub~=1.59.1",
|
||||
"rich~=13.9.4",
|
||||
"rich>=13.9.4",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -142,6 +142,14 @@ python_files = "test_*.py"
|
||||
python_classes = "Test*"
|
||||
python_functions = "test_*"
|
||||
|
||||
[tool.uv]
|
||||
|
||||
# composio-core pins rich<14 but textual requires rich>=14.
|
||||
# onnxruntime 1.24+ dropped Python 3.10 wheels; cap it so qdrant[fastembed] resolves on 3.10.
|
||||
override-dependencies = [
|
||||
"rich>=13.7.1",
|
||||
"onnxruntime<1.24; python_version < '3.11'",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
|
||||
Reference in New Issue
Block a user