fixes mypy issues

This commit is contained in:
Dev-Khant
2024-09-23 16:02:04 +05:30
parent 7371a454ad
commit 66a1ecca00
5 changed files with 46 additions and 14 deletions

View File

@@ -23,5 +23,13 @@ class Memory:
self.storage.save(value, metadata) self.storage.save(value, metadata)
def search(self, query: str) -> Dict[str, Any]: def search(
return self.storage.search(query) self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
) -> Dict[str, Any]:
return self.storage.search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
)

View File

@@ -42,8 +42,16 @@ class ShortTermMemory(Memory):
super().save(value=item.data, metadata=item.metadata, agent=item.agent) super().save(value=item.data, metadata=item.metadata, agent=item.agent)
def search(self, query: str, score_threshold: float = 0.35): def search(
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
):
return self.storage.search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None: def reset(self) -> None:
try: try:

View File

@@ -7,7 +7,9 @@ class Storage:
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass pass
def search(self, key: str) -> Dict[str, Any]: # type: ignore def search(
self, query: str, limit: int, filters: Dict, score_threshold: float
) -> Dict[str, Any]: # type: ignore
pass pass
def reset(self) -> None: def reset(self) -> None:

View File

@@ -92,14 +92,14 @@ class RAGStorage(Storage):
self, self,
query: str, query: str,
limit: int = 3, limit: int = 3,
filter: Optional[dict] = None, filters: Optional[dict] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> List[Any]:
with suppress_logging(): with suppress_logging():
try: try:
results = ( results = (
self.app.search(query, limit, where=filter) self.app.search(query, limit, where=filters)
if filter if filters
else self.app.search(query, limit) else self.app.search(query, limit)
) )
except InvalidDimensionException: except InvalidDimensionException:

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
from crewai.memory.user.user_memory_item import UserMemoryItem
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -15,9 +16,22 @@ class UserMemory(Memory):
storage = Mem0Storage(type="user", crew=crew) storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage) super().__init__(storage)
def save(self, item: UserMemoryItem) -> None: def save(
data = f"Remember the details about the user: {item.data}" self,
super().save(data, item.metadata, user=item.user) value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
def search(self, query: str, score_threshold: float = 0.35): def search(
return self.storage.search(query=query, score_threshold=score_threshold) self,
query: str,
limit: int = 3,
filters: dict = {},
score_threshold: float = 0.35,
):
return super().search(
query=query, limit=limit, filters=filters, score_threshold=score_threshold
)