mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
committed by
GitHub
parent
0dfe3bcb0a
commit
5d8f8cbc79
@@ -5,11 +5,6 @@ import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.vectordb.chroma import InvalidDimensionException
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
@@ -29,10 +24,6 @@ def suppress_logging(
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class FakeLLM(BaseLlm):
|
||||
pass
|
||||
|
||||
|
||||
class RAGStorage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
@@ -74,9 +65,19 @@ class RAGStorage(Storage):
|
||||
if embedder_config:
|
||||
config["embedder"] = embedder_config
|
||||
self.type = type
|
||||
self.app = App.from_config(config=config)
|
||||
self.config = config
|
||||
self.allow_reset = allow_reset
|
||||
|
||||
def _initialize_app(self):
|
||||
from embedchain import App
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
class FakeLLM(BaseLlm):
|
||||
pass
|
||||
|
||||
self.app = App.from_config(config=self.config)
|
||||
self.app.llm = FakeLLM()
|
||||
if allow_reset:
|
||||
if self.allow_reset:
|
||||
self.app.reset()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
@@ -86,6 +87,8 @@ class RAGStorage(Storage):
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
self._generate_embedding(value, metadata)
|
||||
|
||||
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
|
||||
@@ -95,6 +98,10 @@ class RAGStorage(Storage):
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
from embedchain.vectordb.chroma import InvalidDimensionException
|
||||
|
||||
with suppress_logging():
|
||||
try:
|
||||
results = (
|
||||
@@ -108,6 +115,10 @@ class RAGStorage(Storage):
|
||||
return [r for r in results if r["metadata"]["score"] >= score_threshold]
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user