fixes mypy issues

This commit is contained in:
Dev-Khant
2024-09-23 16:02:04 +05:30
parent 7371a454ad
commit 66a1ecca00
5 changed files with 46 additions and 14 deletions

View File

@@ -23,5 +23,13 @@ class Memory:
self.storage.save(value, metadata)
def search(self, query: str) -> Dict[str, Any]:
return self.storage.search(query)
def search(
self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
) -> Dict[str, Any]:
return self.storage.search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
)

View File

@@ -42,8 +42,16 @@ class ShortTermMemory(Memory):
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def search(
self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
):
return self.storage.search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:

View File

@@ -7,7 +7,9 @@ class Storage:
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass
def search(self, key: str) -> Dict[str, Any]: # type: ignore
def search(
self, query: str, limit: int, filters: Dict, score_threshold: float
) -> Dict[str, Any]: # type: ignore
pass
def reset(self) -> None:

View File

@@ -92,14 +92,14 @@ class RAGStorage(Storage):
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filters: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
with suppress_logging():
try:
results = (
self.app.search(query, limit, where=filter)
if filter
self.app.search(query, limit, where=filters)
if filters
else self.app.search(query, limit)
)
except InvalidDimensionException:

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional
from crewai.memory.memory import Memory
from crewai.memory.user.user_memory_item import UserMemoryItem
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -15,9 +16,22 @@ class UserMemory(Memory):
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
def save(self, item: UserMemoryItem) -> None:
data = f"Remember the details about the user: {item.data}"
super().save(data, item.metadata, user=item.user)
def save(
self,
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold)
def search(
self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
):
return super().search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
)