mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
import contextlib
|
|
import io
|
|
import logging
|
|
import os
|
|
import shutil
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from crewai.memory.storage.interface import Storage
|
|
from crewai.utilities.paths import db_storage_path
|
|
|
|
|
|
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily silence a logger and all console output.

    While the context is active:

    * the logger named ``logger_name`` has its level set to ``level``
      (ERROR by default), hiding its lower-severity records;
    * everything written to ``sys.stdout`` / ``sys.stderr`` is discarded;
    * a ``UserWarning`` *raised as an exception* inside the block is
      swallowed. Warnings issued via ``warnings.warn`` are not affected by
      ``contextlib.suppress``.

    The logger's own level is restored on exit, including when the block
    raises.

    Args:
        logger_name: Name of the logger to quiet.
        level: Level to apply to that logger for the duration of the block.
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level (possibly NOTSET), not its effective
    # level, so a logger that inherited its level keeps inheriting afterwards.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
            io.StringIO()
        ), contextlib.suppress(UserWarning):
            yield
    finally:
        # Restore even if the body raised something other than UserWarning.
        logger.setLevel(original_level)
|
|
|
|
|
|
class RAGStorage(Storage):
    """
    Extends Storage to handle embeddings for memory entries, improving
    search efficiency.

    Entries are embedded and stored through an embedchain ``App`` whose
    vector database is Chroma, persisted under
    ``<db_storage_path()>/<type>/<sanitized agent roles>``. The ``App`` is
    created lazily on the first ``save``/``search`` call.
    """

    def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
        """
        Build the embedchain configuration for this memory type.

        Args:
            type: Memory type name. Used as the embedchain app name, the Chroma
                collection name and the first storage sub-directory.
            allow_reset: Forwarded to the Chroma config. When true, the app is
                also reset right after it is lazily initialized.
            embedder_config: Optional embedchain ``embedder`` section. It is
                only added to the config when truthy.
            crew: Optional crew. The sanitized roles of its agents, joined by
                ``_``, form the second storage sub-directory.
        """
        super().__init__()
        # Provide a placeholder key when none is configured, unless
        # OPENAI_BASE_URL explicitly points at the official OpenAI endpoint.
        # NOTE(review): presumably so embedchain can build its default OpenAI
        # clients without a real key -- confirm.
        if (
            not os.getenv("OPENAI_API_KEY")
            and os.getenv("OPENAI_BASE_URL") != "https://api.openai.com/v1"
        ):
            os.environ["OPENAI_API_KEY"] = "fake"

        agents = crew.agents if crew else []
        agents_dir = "_".join(self._sanitize_role(agent.role) for agent in agents)

        config = {
            "app": {
                "config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
            },
            "chunker": {
                "chunk_size": 5000,
                "chunk_overlap": 100,
                "length_function": "len",
                "min_chunk_size": 150,
            },
            "vectordb": {
                "provider": "chroma",
                "config": {
                    "collection_name": type,
                    "dir": f"{db_storage_path()}/{type}/{agents_dir}",
                    "allow_reset": allow_reset,
                },
            },
        }

        if embedder_config:
            config["embedder"] = embedder_config
        self.type = type
        self.config = config
        self.allow_reset = allow_reset

    def _initialize_app(self):
        """Create ``self.app`` from ``self.config`` (embedchain is imported lazily)."""
        from embedchain import App
        from embedchain.llm.base import BaseLlm

        # This storage only adds, searches and resets documents, so the app's
        # LLM is replaced with a no-op subclass.
        class FakeLLM(BaseLlm):
            pass

        self.app = App.from_config(config=self.config)
        self.app.llm = FakeLLM()
        if self.allow_reset:
            self.app.reset()

    def _sanitize_role(self, role: str) -> str:
        """
        Sanitizes agent roles to ensure valid directory names.

        Newlines are removed; spaces and ``/`` become ``_``.
        """
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Embed ``value`` and store it together with ``metadata``."""
        if not hasattr(self, "app"):
            self._initialize_app()
        self._generate_embedding(value, metadata)

    def search(  # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        """
        Return up to ``limit`` stored entries matching ``query``.

        Args:
            query: Text to search for.
            limit: Maximum number of results requested from the app.
            filter: Optional metadata filter, passed as ``where=`` when truthy.
            score_threshold: Results whose ``metadata["score"]`` is below this
                value are dropped.

        Returns:
            The filtered results. Returns an empty list if the app raised
            ``InvalidDimensionException``; in that case the app is reset
            first, which wipes its stored entries.
        """
        if not hasattr(self, "app"):
            self._initialize_app()
        from embedchain.vectordb.chroma import InvalidDimensionException

        with suppress_logging():
            try:
                results = (
                    self.app.search(query, limit, where=filter)
                    if filter
                    else self.app.search(query, limit)
                )
            except InvalidDimensionException:
                # NOTE(review): presumably the stored vectors were produced by
                # a different embedder; the collection is wiped -- confirm.
                self.app.reset()
                return []
        return [r for r in results if r["metadata"]["score"] >= score_threshold]

    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
        """Add ``text`` to the app as a TEXT document with ``metadata``."""
        if not hasattr(self, "app"):
            self._initialize_app()
        from embedchain.models.data_type import DataType

        self.app.add(text, data_type=DataType.TEXT, metadata=metadata)

    def reset(self) -> None:
        """
        Delete all persisted data for this memory type.

        A missing storage directory means nothing was ever persisted and is
        treated as already reset.

        Raises:
            Exception: if the storage directory exists but cannot be removed.
                The original error is chained as ``__cause__``.
        """
        # NOTE(review): an already-initialized ``self.app`` is neither reset
        # nor dropped here; confirm it is safe to reuse after its directory
        # has been removed.
        try:
            shutil.rmtree(f"{db_storage_path()}/{self.type}")
        except FileNotFoundError:
            pass
        except Exception as e:
            raise Exception(
                f"An error occurred while resetting the {self.type} memory: {e}"
            ) from e