mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Feat/memory base (#1444)
* byom - short/entity memory * better * rm uneeded * fix text * use context * rm dep and sync * type check fix * fixed test using new cassete * fixing types * fixed types * fix types * fixed types * fixing types * fix type * cassette update * just mock the return of short term mem * remove print * try catch block * added docs * dding error handling here
This commit is contained in:
76
src/crewai/memory/storage/base_rag_storage.py
Normal file
76
src/crewai/memory/storage/base_rag_storage.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
Base class for RAG-based Storage implementations.
|
||||
"""
|
||||
|
||||
app: Any | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Any] = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
self.crew = crew
|
||||
self.agents = self._initialize_agents()
|
||||
|
||||
def _initialize_agents(self) -> str:
|
||||
if self.crew:
|
||||
return "_".join(
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents]
|
||||
)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
"""Generate an embedding for the given text and metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_app(self):
|
||||
"""Initialize the vector db."""
|
||||
pass
|
||||
|
||||
def setup_config(self, config: Dict[str, Any]):
|
||||
"""Setup the config of the storage."""
|
||||
pass
|
||||
|
||||
def initialize_client(self):
|
||||
"""Initialize the client of the storage. This should setup the app and the db collection"""
|
||||
pass
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class Storage:
|
||||
@@ -7,7 +7,7 @@ class Storage:
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(self, key: str) -> Dict[str, Any]: # type: ignore
|
||||
def search(self, key: str) -> List[Dict[str, Any]]: # type: ignore
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -3,10 +3,11 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -24,61 +25,42 @@ def suppress_logging(
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class RAGStorage(Storage):
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
|
||||
super().__init__()
|
||||
if (
|
||||
not os.getenv("OPENAI_API_KEY")
|
||||
and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1"
|
||||
):
|
||||
os.environ["OPENAI_API_KEY"] = "fake"
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
|
||||
config = {
|
||||
"app": {
|
||||
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
|
||||
},
|
||||
"chunker": {
|
||||
"chunk_size": 5000,
|
||||
"chunk_overlap": 100,
|
||||
"length_function": "len",
|
||||
"min_chunk_size": 150,
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": type,
|
||||
"dir": f"{db_storage_path()}/{type}/{agents}",
|
||||
"allow_reset": allow_reset,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if embedder_config:
|
||||
config["embedder"] = embedder_config
|
||||
self.type = type
|
||||
self.config = config
|
||||
self.embedder_config = embedder_config or self._create_embedding_function()
|
||||
self.allow_reset = allow_reset
|
||||
self._initialize_app()
|
||||
|
||||
def _initialize_app(self):
|
||||
from embedchain import App
|
||||
from embedchain.llm.base import BaseLlm
|
||||
import chromadb
|
||||
|
||||
class FakeLLM(BaseLlm):
|
||||
pass
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=f"{db_storage_path()}/{self.type}/{self.agents}"
|
||||
)
|
||||
self.app = chroma_client
|
||||
|
||||
self.app = App.from_config(config=self.config)
|
||||
self.app.llm = FakeLLM()
|
||||
if self.allow_reset:
|
||||
self.app.reset()
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
@@ -87,11 +69,14 @@ class RAGStorage(Storage):
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app"):
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
self._generate_embedding(value, metadata)
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
|
||||
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
@@ -100,31 +85,50 @@ class RAGStorage(Storage):
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
from embedchain.vectordb.chroma import InvalidDimensionException
|
||||
|
||||
with suppress_logging():
|
||||
try:
|
||||
results = (
|
||||
self.app.search(query, limit, where=filter)
|
||||
if filter
|
||||
else self.app.search(query, limit)
|
||||
)
|
||||
except InvalidDimensionException:
|
||||
self.app.reset()
|
||||
return []
|
||||
return [r for r in results if r["metadata"]["score"] >= score_threshold]
|
||||
try:
|
||||
with suppress_logging():
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
|
||||
if not hasattr(self, "app"):
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i],
|
||||
"context": response["documents"][0][i],
|
||||
"score": response["distances"][0][i],
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_embedding_function(self):
|
||||
import chromadb.utils.embedding_functions as embedding_functions
|
||||
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user