Adding long term, short term, entity and contextual memory

This commit is contained in:
João Moura
2024-04-01 04:45:56 -03:00
parent a6c3b1f1d4
commit 5b59e450f7
30 changed files with 709 additions and 83 deletions

View File

@@ -0,0 +1,3 @@
from .entity.entity_memory import EntityMemory
from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory

View File

View File

@@ -0,0 +1,58 @@
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory
class ContextualMemory:
def __init__(self, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory):
self.stm = stm
self.ltm = ltm
self.em = em
def build_context_for_task(self, task, context) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""
query = f"{task.description} {context}".strip()
if query == "":
return ""
context = []
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
"""
stm_results = self.stm.search(query)
formatted_results = "\n".join([f"- {result}" for result in stm_results])
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> str:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
ltm_results = self.ltm.search(task)
if not ltm_results:
return None
formatted_results = "\n".join(
[f"{result['metadata']['suggestions']}" for result in ltm_results]
)
formatted_results = list(set(formatted_results))
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
"""
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""

View File

View File

@@ -0,0 +1,22 @@
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
class EntityMemory(Memory):
"""
EntityMemory class for managing structured information about entities
and their relationships using SQLite storage.
Inherits from the Memory class.
"""
def __init__(self, embedder_config=None):
storage = RAGStorage(
type="entities", allow_reset=False, embedder_config=embedder_config
)
super().__init__(storage)
def save(self, item: EntityMemoryItem) -> None:
"""Saves an entity item into the SQLite storage."""
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)

View File

@@ -0,0 +1,12 @@
class EntityMemoryItem:
def __init__(
self,
name: str,
type: str,
description: str,
relationships: str,
):
self.name = name
self.type = type
self.description = description
self.metadata = {"relationships": relationships}

View File

View File

@@ -0,0 +1,32 @@
from typing import Any, Dict
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
class LongTermMemory(Memory):
"""
LongTermMemory class for managing cross runs data related to overall crew's
execution and performance.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
LongTermMemoryItem instances.
"""
def __init__(self):
storage = LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None:
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
def search(self, task: str) -> Dict[str, Any]:
return self.storage.load(task)

View File

@@ -0,0 +1,19 @@
from typing import Any, Dict, Union
class LongTermMemoryItem:
def __init__(
self,
agent: str,
task: str,
expected_output: str,
datetime: str,
quality: Union[int, float] = None,
metadata: Dict[str, Any] = None,
):
self.task = task
self.agent = agent
self.quality = quality
self.datetime = datetime
self.expected_output = expected_output
self.metadata = metadata if metadata is not None else {}

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from crewai.memory.storage.interface import Storage
class Memory:
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
def __init__(self, storage: Storage):
self.storage = storage
def save(
self, value: Any, metadata: Dict[str, Any] = None, agent: str = None
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
self.storage.save(value, metadata)
def search(self, query: str) -> Dict[str, Any]:
return self.storage.search(query)

View File

View File

@@ -0,0 +1,23 @@
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
class ShortTermMemory(Memory):
"""
ShortTermMemory class for managing transient data related to immediate tasks
and interactions.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
MemoryItem instances.
"""
def __init__(self, embedder_config=None):
storage = RAGStorage(type="short_term", embedder_config=embedder_config)
super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None:
super().save(item.data, item.metadata, item.agent)
def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold)

View File

@@ -0,0 +1,8 @@
from typing import Any, Dict
class ShortTermMemoryItem:
def __init__(self, data: Any, agent: str, metadata: Dict[str, Any] = None):
self.data = data
self.agent = agent
self.metadata = metadata if metadata is not None else {}

View File

@@ -0,0 +1,11 @@
from typing import Any, Dict
class Storage:
"""Abstract base class defining the storage interface"""
def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
pass
def search(self, key: str) -> Dict[str, Any]:
pass

View File

@@ -0,0 +1,100 @@
import json
import sqlite3
from typing import Any, Dict, Union
from crewai.utilities import Printer
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
def __init__(self, db_path=".db/long_term_memory_storage.db"):
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS long_term_memories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_description TEXT,
metadata TEXT,
datetime TEXT,
score REAL
)
"""
)
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred during database initialization: {e}",
color="red",
)
def save(
self,
task_description: str,
metadata: Dict[str, Any],
datetime: str,
score: Union[int, float],
) -> None:
"""Saves data to the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
VALUES (?, ?, ?, ?)
""",
(task_description, json.dumps(metadata), datetime, score),
)
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
def load(self, task_description: str) -> Dict[str, Any]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT 2
""",
(task_description,),
)
rows = cursor.fetchall()
if rows:
return [
{
"metadata": json.loads(row[0]),
"datetime": row[1],
"score": row[2],
}
for row in rows
]
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None

View File

@@ -0,0 +1,87 @@
import contextlib
import io
import logging
from typing import Any, Dict
from embedchain import App
from embedchain.llm.base import BaseLlm
from crewai.memory.storage.interface import Storage
@contextlib.contextmanager
def suppress_logging(
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
level=logging.ERROR,
):
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
io.StringIO()
), contextlib.suppress(UserWarning):
yield
logger.setLevel(original_level)
class FakeLLM(BaseLlm):
pass
class RAGStorage(Storage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
def __init__(self, type, allow_reset=True, embedder_config=None):
super().__init__()
config = {
"app": {
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
},
"chunker": {
"chunk_size": 5000,
"chunk_overlap": 100,
"length_function": "len",
"min_chunk_size": 150,
},
"vectordb": {
"provider": "chroma",
"config": {
"collection_name": type,
"dir": f".db/{type}",
"allow_reset": allow_reset,
},
},
}
if embedder_config:
config["embedder"] = embedder_config
self.app = App.from_config(config=config)
self.app.llm = FakeLLM()
if allow_reset:
self.app.reset()
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self._generate_embedding(value, metadata)
def search(
self,
query: str,
limit: int = 3,
filter: dict = None,
score_threshold: float = 0.35,
) -> Dict[str, Any]:
with suppress_logging():
results = (
self.app.search(query, limit, where=filter)
if filter
else self.app.search(query, limit)
)
return [r for r in results if r["metadata"]["score"] >= score_threshold]
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
with suppress_logging():
self.app.add(text, data_type="text", metadata=metadata)