mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Add support for retrieving user preferences and memories using Mem0 (#1209)
* Integrate Mem0 * Update src/crewai/memory/contextual/contextual_memory.py Co-authored-by: Deshraj Yadav <deshraj@gatech.edu> * pending commit for _fetch_user_memories * update poetry.lock * fixes mypy issues * fix mypy checks * New fixes for user_id * remove memory_provider * handle memory_provider * checks for memory_config * add mem0 to dependency * Update pyproject.toml Co-authored-by: Deshraj Yadav <deshraj@gatech.edu> * update docs * update doc * bump mem0 version * fix api error msg and mypy issue * mypy fix * resolve comments * fix memory usage without mem0 * mem0 version bump * lazy import mem0 --------- Co-authored-by: Deshraj Yadav <deshraj@gatech.edu> Co-authored-by: João Moura <joaomdmoura@gmail.com> Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from .entity.entity_memory import EntityMemory
|
||||
from .long_term.long_term_memory import LongTermMemory
|
||||
from .short_term.short_term_memory import ShortTermMemory
|
||||
from .user.user_memory import UserMemory
|
||||
|
||||
__all__ = ["EntityMemory", "LongTermMemory", "ShortTermMemory"]
|
||||
__all__ = ["UserMemory", "EntityMemory", "LongTermMemory", "ShortTermMemory"]
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory
|
||||
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
def __init__(self, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory):
|
||||
def __init__(
|
||||
self,
|
||||
memory_config: Optional[Dict[str, Any]],
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
um: UserMemory,
|
||||
):
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
self.um = um
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
@@ -23,6 +35,8 @@ class ContextualMemory:
|
||||
context.append(self._fetch_ltm_context(task.description))
|
||||
context.append(self._fetch_stm_context(query))
|
||||
context.append(self._fetch_entity_context(query))
|
||||
if self.memory_provider == "mem0":
|
||||
context.append(self._fetch_user_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
@@ -32,7 +46,10 @@ class ContextualMemory:
|
||||
"""
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in stm_results
|
||||
]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
@@ -62,6 +79,26 @@ class ContextualMemory:
|
||||
"""
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in em_results
|
||||
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
def _fetch_user_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant user information from User Memory.
|
||||
Args:
|
||||
query (str): The search query to find relevant user memories.
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['memory']}" for result in user_memories
|
||||
)
|
||||
return f"User memories/preferences:\n{formatted_memories}"
|
||||
|
||||
@@ -11,21 +11,43 @@ class EntityMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None):
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
if self.memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=False,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
)
|
||||
)
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
if self.memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -23,5 +23,12 @@ class Memory:
|
||||
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
def search(self, query: str) -> List[Dict[str, Any]]:
|
||||
return self.storage.search(query)
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
@@ -14,13 +14,27 @@ class ShortTermMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None):
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="short_term", embedder_config=embedder_config, crew=crew
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
if self.memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="short_term", embedder_config=embedder_config, crew=crew
|
||||
)
|
||||
)
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(
|
||||
@@ -30,11 +44,20 @@ class ShortTermMemory(Memory):
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self.memory_provider == "mem0":
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
|
||||
def search(self, query: str, score_threshold: float = 0.35):
|
||||
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
|
||||
@@ -7,8 +7,10 @@ class Storage:
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(self, key: str) -> List[Dict[str, Any]]: # type: ignore
|
||||
pass
|
||||
def search(
|
||||
self, query: str, limit: int, score_threshold: float
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
return {}
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
104
src/crewai/memory/storage/mem0_storage.py
Normal file
104
src/crewai/memory/storage/mem0_storage.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from mem0 import MemoryClient
|
||||
from crewai.memory.storage.interface import Storage
|
||||
|
||||
|
||||
class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
|
||||
def __init__(self, type, crew=None):
|
||||
super().__init__()
|
||||
|
||||
if type not in ["user", "short_term", "long_term", "entities"]:
|
||||
raise ValueError("Invalid type for Mem0Storage. Must be 'user' or 'agent'.")
|
||||
|
||||
self.memory_type = type
|
||||
self.crew = crew
|
||||
self.memory_config = crew.memory_config
|
||||
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
if type == "user" and not user_id:
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
|
||||
"MEM0_API_KEY"
|
||||
)
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
user_id = self._get_user_id()
|
||||
agent_name = self._get_agent_name()
|
||||
if self.memory_type == "user":
|
||||
self.memory.add(value, user_id=user_id, metadata={**metadata})
|
||||
elif self.memory_type == "short_term":
|
||||
agent_name = self._get_agent_name()
|
||||
self.memory.add(
|
||||
value, agent_id=agent_name, metadata={"type": "short_term", **metadata}
|
||||
)
|
||||
elif self.memory_type == "long_term":
|
||||
agent_name = self._get_agent_name()
|
||||
self.memory.add(
|
||||
value,
|
||||
agent_id=agent_name,
|
||||
infer=False,
|
||||
metadata={"type": "long_term", **metadata},
|
||||
)
|
||||
elif self.memory_type == "entities":
|
||||
entity_name = None
|
||||
self.memory.add(
|
||||
value, user_id=entity_name, metadata={"type": "entity", **metadata}
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
params = {"query": query, "limit": limit}
|
||||
if self.memory_type == "user":
|
||||
user_id = self._get_user_id()
|
||||
params["user_id"] = user_id
|
||||
elif self.memory_type == "short_term":
|
||||
agent_name = self._get_agent_name()
|
||||
params["agent_id"] = agent_name
|
||||
params["metadata"] = {"type": "short_term"}
|
||||
elif self.memory_type == "long_term":
|
||||
agent_name = self._get_agent_name()
|
||||
params["agent_id"] = agent_name
|
||||
params["metadata"] = {"type": "long_term"}
|
||||
elif self.memory_type == "entities":
|
||||
agent_name = self._get_agent_name()
|
||||
params["agent_id"] = agent_name
|
||||
params["metadata"] = {"type": "entity"}
|
||||
|
||||
# Discard the filters for now since we create the filters
|
||||
# automatically when the crew is created.
|
||||
results = self.memory.search(**params)
|
||||
return [r for r in results if r["score"] >= score_threshold]
|
||||
|
||||
def _get_user_id(self):
|
||||
if self.memory_type == "user":
|
||||
if hasattr(self, "memory_config") and self.memory_config is not None:
|
||||
return self.memory_config.get("config", {}).get("user_id")
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _get_agent_name(self):
|
||||
agents = self.crew.agents if self.crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return agents
|
||||
0
src/crewai/memory/user/__init__.py
Normal file
0
src/crewai/memory/user/__init__.py
Normal file
45
src/crewai/memory/user/user_memory.py
Normal file
45
src/crewai/memory/user/user_memory.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from crewai.memory.memory import Memory
|
||||
|
||||
|
||||
class UserMemory(Memory):
|
||||
"""
|
||||
UserMemory class for handling user memory storage and retrieval.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
storage = Mem0Storage(type="user", crew=crew)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||
data = f"Remember the details about the user: {value}"
|
||||
super().save(data, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
results = super().search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
8
src/crewai/memory/user/user_memory_item.py
Normal file
8
src/crewai/memory/user/user_memory_item.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class UserMemoryItem:
|
||||
def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
self.data = data
|
||||
self.user = user
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
Reference in New Issue
Block a user