revert ltm

This commit is contained in:
Brandon Hancock
2025-02-07 13:57:23 -05:00
parent 99ec3c17b6
commit 85334cb617
2 changed files with 12 additions and 7 deletions

View File

@@ -1,3 +1,7 @@
from typing import Any, Optional
from pydantic import PrivateAttr
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage from crewai.memory.storage.rag_storage import RAGStorage
@@ -10,13 +14,15 @@ class EntityMemory(Memory):
Inherits from the Memory class. Inherits from the Memory class.
""" """
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider") memory_provider = crew.memory_config.get("provider")
else: else:
self.memory_provider = None memory_provider = None
if self.memory_provider == "mem0": if memory_provider == "mem0":
try: try:
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError: except ImportError:
@@ -36,11 +42,13 @@ class EntityMemory(Memory):
path=path, path=path,
) )
) )
super().__init__(storage=storage) super().__init__(storage=storage)
self._memory_provider = memory_provider
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage.""" """Saves an entity item into the SQLite storage."""
if self.memory_provider == "mem0": if self._memory_provider == "mem0":
data = f""" data = f"""
Remember details about the following entity: Remember details about the following entity:
Name: {item.name} Name: {item.name}

View File

@@ -19,7 +19,6 @@ class ShortTermMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr() _memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
# Determine memory_provider without assigning it directly as a public field.
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider") memory_provider = crew.memory_config.get("provider")
else: else:
@@ -44,9 +43,7 @@ class ShortTermMemory(Memory):
path=path, path=path,
) )
) )
# First call the parent __init__ so that Pydantic's internals are set.
super().__init__(storage=storage) super().__init__(storage=storage)
# Now assign the private attribute.
self._memory_provider = memory_provider self._memory_provider = memory_provider
def save( def save(