From dffca0cb8a7d6b8acedd3dc926d56119133f2d6f Mon Sep 17 00:00:00 2001 From: lucasgomide Date: Mon, 31 Mar 2025 13:07:56 -0300 Subject: [PATCH] feat: prepare Mem0Storage to accept config paramenter We're planning to remove `memory_config` soon. This commit kindly prepare this storage to accept the config provided directly --- src/crewai/memory/storage/mem0_storage.py | 43 ++++++++++++----------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 782b17152..ce526a457 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -11,7 +11,7 @@ class Mem0Storage(Storage): Extends Storage to handle embedding and searching across entities using Mem0. """ - def __init__(self, type, crew=None): + def __init__(self, type, crew=None, config=None): super().__init__() supported_types = ["user", "short_term", "long_term", "entities"] if type not in supported_types: @@ -22,7 +22,9 @@ class Mem0Storage(Storage): self.memory_type = type self.crew = crew - self.memory_config = crew.memory_config + self.config = config or {} + # TODO: Memory config will be removed in the future the config will be passed as a parameter + self.memory_config = self.config or getattr(crew, "memory_config", {}) or {} # User ID is required for user memory type "user" since it's used as a unique identifier for the user. user_id = self._get_user_id() @@ -30,7 +32,7 @@ class Mem0Storage(Storage): raise ValueError("User ID is required for user memory type") # API key in memory config overrides the environment variable - config = self.memory_config.get("config", {}) + config = self._get_config() mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY") mem0_org_id = config.get("org_id") mem0_project_id = config.get("project_id") @@ -59,15 +61,14 @@ class Mem0Storage(Storage): def save(self, value: Any, metadata: Dict[str, Any]) -> None: user_id = self._get_user_id() agent_name = self._get_agent_name() - if self.memory_type == "user": + if user_id: self.memory.add(value, user_id=user_id, metadata={**metadata}) - elif self.memory_type == "short_term": - agent_name = self._get_agent_name() + + if self.memory_type == "short_term": self.memory.add( value, agent_id=agent_name, metadata={"type": "short_term", **metadata} ) elif self.memory_type == "long_term": - agent_name = self._get_agent_name() self.memory.add( value, agent_id=agent_name, @@ -75,9 +76,8 @@ class Mem0Storage(Storage): metadata={"type": "long_term", **metadata}, ) elif self.memory_type == "entities": - entity_name = self._get_agent_name() self.memory.add( - value, user_id=entity_name, metadata={"type": "entity", **metadata} + value, user_id=agent_name, metadata={"type": "entity", **metadata} ) def search( @@ -87,10 +87,10 @@ class Mem0Storage(Storage): score_threshold: float = 0.35, ) -> List[Any]: params = {"query": query, "limit": limit} - if self.memory_type == "user": - user_id = self._get_user_id() + if user_id := self._get_user_id(): params["user_id"] = user_id - elif self.memory_type == "short_term": + + if self.memory_type == "short_term": agent_name = self._get_agent_name() params["agent_id"] = agent_name params["metadata"] = {"type": "short_term"} @@ -108,20 +108,21 @@ class Mem0Storage(Storage): results = self.memory.search(**params) return [r for r in results if r["score"] >= score_threshold] - def _get_user_id(self): - if self.memory_type == "user": - if hasattr(self, "memory_config") and self.memory_config is not None: - return self.memory_config.get("config", {}).get("user_id") - else: - return None - return None + def _get_user_id(self) -> str: + return self._get_config().get("user_id", "") - def _get_agent_name(self): - agents = self.crew.agents if self.crew else [] + def _get_agent_name(self) -> str: + if not self.crew: + return "" + + agents = self.crew.agents agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) return agents + def _get_config(self) -> Dict[str, Any]: + return self.config or getattr(self, "memory_config", {}).get("config", {}) or {} + def reset(self): if self.memory: self.memory.reset()