mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
clean up short term memory
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
@@ -14,15 +16,16 @@ class ShortTermMemory(Memory):
|
|||||||
MemoryItem instances.
|
MemoryItem instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
memory_provider: Any
|
_memory_provider: Optional[str] = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||||
|
# Determine memory_provider without assigning it directly as a public field.
|
||||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||||
self.memory_provider = crew.memory_config.get("provider")
|
memory_provider = crew.memory_config.get("provider")
|
||||||
else:
|
else:
|
||||||
self.memory_provider = None
|
memory_provider = None
|
||||||
|
|
||||||
if self.memory_provider == "mem0":
|
if memory_provider == "mem0":
|
||||||
try:
|
try:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -41,7 +44,10 @@ class ShortTermMemory(Memory):
|
|||||||
path=path,
|
path=path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# First call the parent __init__ so that Pydantic's internals are set.
|
||||||
super().__init__(storage=storage)
|
super().__init__(storage=storage)
|
||||||
|
# Now assign the private attribute.
|
||||||
|
self._memory_provider = memory_provider
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
@@ -50,7 +56,7 @@ class ShortTermMemory(Memory):
|
|||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||||
if self.memory_provider == "mem0":
|
if self._memory_provider == "mem0":
|
||||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||||
|
|
||||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||||
|
|||||||
Reference in New Issue
Block a user