From 8d8772d607305843a04d62381d1c9d0c4baed028 Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Wed, 22 Oct 2025 12:52:28 -0400 Subject: [PATCH] chore: cleanup memo0, add typing --- .../src/crewai/memory/storage/mem0_storage.py | 210 ++++++++++-------- lib/crewai/src/crewai/utilities/types.py | 5 +- 2 files changed, 125 insertions(+), 90 deletions(-) diff --git a/lib/crewai/src/crewai/memory/storage/mem0_storage.py b/lib/crewai/src/crewai/memory/storage/mem0_storage.py index 73820ab11..90445d98b 100644 --- a/lib/crewai/src/crewai/memory/storage/mem0_storage.py +++ b/lib/crewai/src/crewai/memory/storage/mem0_storage.py @@ -1,16 +1,83 @@ -from collections import defaultdict +from __future__ import annotations + from collections.abc import Iterable import os import re -from typing import Any +from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict -from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found] +from mem0 import Memory, MemoryClient # type: ignore[import-untyped] from crewai.memory.storage.interface import Storage from crewai.rag.chromadb.utils import _sanitize_collection_name -MAX_AGENT_ID_LENGTH_MEM0 = 255 +if TYPE_CHECKING: + from crewai.crew import Crew + from crewai.utilities.types import LLMMessage, MessageRole + + +MAX_AGENT_ID_LENGTH_MEM0: Final[int] = 255 +_ASSISTANT_MESSAGE_MARKER: Final[str] = "Final Answer:" +_USER_MESSAGE_PATTERN: Final[re.Pattern[str]] = re.compile(r"User message:\s*(.*)") + + +class BaseMetadata(TypedDict): + short_term: Literal["short_term"] + long_term: Literal["long_term"] + entities: Literal["entity"] + external: Literal["external"] + + +BASE_METADATA: Final[BaseMetadata] = { + "short_term": "short_term", + "long_term": "long_term", + "entities": "entity", + "external": "external", +} + +MEMORY_TYPE_MAP: Final[dict[str, dict[str, str]]] = { + "short_term": {"type": "short_term"}, + "long_term": {"type": "long_term"}, + "entities": {"type": "entity"}, + "external": {"type": "external"}, +} + + +class BaseParams(TypedDict, total=False): + """Parameters for Mem0 memory operations.""" + + metadata: dict[str, Any] + infer: bool + includes: Any + excludes: Any + output_format: str + version: str + run_id: str + user_id: str + agent_id: str + + +class Mem0Config(TypedDict, total=False): + """Configuration for Mem0Storage.""" + + run_id: str + includes: Any + excludes: Any + custom_categories: Any + infer: bool + api_key: str + org_id: str + project_id: str + local_mem0_config: Any + user_id: str + agent_id: str + + +class Mem0Filter(TypedDict, total=False): + """Filter dictionary for Mem0 search operations.""" + + AND: list[dict[str, Any]] + OR: list[dict[str, Any]] class Mem0Storage(Storage): @@ -18,33 +85,22 @@ class Mem0Storage(Storage): Extends Storage to handle embedding and searching across entities using Mem0. """ - def __init__(self, type, crew=None, config=None): - super().__init__() - - self._validate_type(type) + def __init__( + self, + type: Literal["short_term", "long_term", "entities", "external"], + crew: Crew | None = None, + config: Mem0Config | None = None, + ) -> None: self.memory_type = type self.crew = crew - self.config = config or {} - - self._extract_config_values() - self._initialize_memory() - - def _validate_type(self, type): - supported_types = {"short_term", "long_term", "entities", "external"} - if type not in supported_types: - raise ValueError( - f"Invalid type '{type}' for Mem0Storage. " - f"Must be one of: {', '.join(supported_types)}" - ) - - def _extract_config_values(self): - self.mem0_run_id = self.config.get("run_id") - self.includes = self.config.get("includes") - self.excludes = self.config.get("excludes") - self.custom_categories = self.config.get("custom_categories") - self.infer = self.config.get("infer", True) - - def _initialize_memory(self): + if config is None: + config = {} + self.config: Mem0Config = config + self.mem0_run_id = config.get("run_id") + self.includes = config.get("includes") + self.excludes = config.get("excludes") + self.custom_categories = config.get("custom_categories") + self.infer = config.get("infer", True) api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY") org_id = self.config.get("org_id") project_id = self.config.get("project_id") @@ -65,47 +121,39 @@ class Mem0Storage(Storage): else Memory() ) - def _create_filter_for_search(self): - """ + def _create_filter_for_search(self) -> Mem0Filter: + """Create filter dictionary for search operations. + Returns: - dict: A filter dictionary containing AND conditions for querying data. - - Includes user_id and agent_id if both are present. - - Includes user_id if only user_id is present. - - Includes agent_id if only agent_id is present. - - Includes run_id if memory_type is 'short_term' and - mem0_run_id is present. + Filter dictionary containing AND/OR conditions for querying data. """ - filter = defaultdict(list) - if self.memory_type == "short_term" and self.mem0_run_id: - filter["AND"].append({"run_id": self.mem0_run_id}) - else: - user_id = self.config.get("user_id", "") - agent_id = self.config.get("agent_id", "") + return {"AND": [{"run_id": self.mem0_run_id}]} - if user_id and agent_id: - filter["OR"].append({"user_id": user_id}) - filter["OR"].append({"agent_id": agent_id}) - elif user_id: - filter["AND"].append({"user_id": user_id}) - elif agent_id: - filter["AND"].append({"agent_id": agent_id}) - - return filter + user_id = self.config.get("user_id") + agent_id = self.config.get("agent_id") + if user_id and agent_id: + return {"OR": [{"user_id": user_id}, {"agent_id": agent_id}]} + if user_id: + return {"AND": [{"user_id": user_id}]} + if agent_id: + return {"AND": [{"agent_id": agent_id}]} + return {} def save(self, value: Any, metadata: dict[str, Any]) -> None: - def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: - return next( + def _last_content(messages_: Iterable[LLMMessage], role: MessageRole) -> str: + content = next( ( m.get("content", "") - for m in reversed(list(messages)) + for m in reversed(list(messages_)) if m.get("role") == role ), "", ) + return str(content) if content else "" conversations = [] - messages = metadata.pop("messages", None) + messages: Iterable[LLMMessage] = metadata.pop("messages", []) if messages: last_user = _last_content(messages, "user") last_assistant = _last_content(messages, "assistant") @@ -120,20 +168,11 @@ class Mem0Storage(Storage): user_id = self.config.get("user_id", "") - base_metadata = { - "short_term": "short_term", - "long_term": "long_term", - "entities": "entity", - "external": "external", - } - - # Shared base params - params: dict[str, Any] = { - "metadata": {"type": base_metadata[self.memory_type], **metadata}, + params: BaseParams = { + "metadata": {"type": BASE_METADATA[self.memory_type], **metadata}, "infer": self.infer, } - # MemoryClient-specific overrides if isinstance(self.memory, MemoryClient): params["includes"] = self.includes params["excludes"] = self.excludes @@ -154,7 +193,7 @@ class Mem0Storage(Storage): def search( self, query: str, limit: int = 5, score_threshold: float = 0.6 ) -> list[Any]: - params = { + params: dict[str, Any] = { "query": query, "limit": limit, "version": "v2", @@ -164,15 +203,8 @@ class Mem0Storage(Storage): if user_id := self.config.get("user_id", ""): params["user_id"] = user_id - memory_type_map = { - "short_term": {"type": "short_term"}, - "long_term": {"type": "long_term"}, - "entities": {"type": "entity"}, - "external": {"type": "external"}, - } - - if self.memory_type in memory_type_map: - params["metadata"] = memory_type_map[self.memory_type] + if self.memory_type in MEMORY_TYPE_MAP: + params["metadata"] = MEMORY_TYPE_MAP[self.memory_type] if self.memory_type == "short_term": params["run_id"] = self.mem0_run_id @@ -195,11 +227,12 @@ class Mem0Storage(Storage): return [r for r in results["results"]] - def reset(self): + def reset(self) -> None: if self.memory: self.memory.reset() - def _sanitize_role(self, role: str) -> str: + @staticmethod + def _sanitize_role(role: str) -> str: """ Sanitizes agent roles to ensure valid directory names. """ @@ -210,21 +243,20 @@ class Mem0Storage(Storage): return "" agents = self.crew.agents - agents = [self._sanitize_role(agent.role) for agent in agents] - agents = "_".join(agents) + agents_roles = "".join([self._sanitize_role(agent.role) for agent in agents]) return _sanitize_collection_name( - name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0 + name=agents_roles, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0 ) - def _get_assistant_message(self, text: str) -> str: - marker = "Final Answer:" - if marker in text: - return text.split(marker, 1)[1].strip() + @staticmethod + def _get_assistant_message(text: str) -> str: + if _ASSISTANT_MESSAGE_MARKER in text: + return text.split(_ASSISTANT_MESSAGE_MARKER, 1)[1].strip() return text - def _get_user_message(self, text: str) -> str: - pattern = r"User message:\s*(.*)" - match = re.search(pattern, text) + @staticmethod + def _get_user_message(text: str) -> str: + match = _USER_MESSAGE_PATTERN.search(text) if match: return match.group(1).strip() return text diff --git a/lib/crewai/src/crewai/utilities/types.py b/lib/crewai/src/crewai/utilities/types.py index bc331a97e..d5cd832db 100644 --- a/lib/crewai/src/crewai/utilities/types.py +++ b/lib/crewai/src/crewai/utilities/types.py @@ -3,6 +3,9 @@ from typing import Any, Literal, TypedDict +MessageRole = Literal["user", "assistant", "system"] + + class LLMMessage(TypedDict): """Type for formatted LLM messages. @@ -11,5 +14,5 @@ class LLMMessage(TypedDict): instead of str | list[dict[str, str]] """ - role: Literal["user", "assistant", "system"] + role: MessageRole content: str | list[dict[str, Any]]