feat: add mypy as type checker, update code and add comment to reference (#591)

* fix: fix test actually running

* fix: fix test to not send request to openai

* fix: fix linting to remove cli files

* fix: exclude only files that breaks black

* fix: Fix all Ruff checkings on the code and Fix Test with repeated name

* fix: Change linter name on yml file

* feat: update pre-commit

* feat: remove need for isort on the code

* feat: add mypy as type checker, update code and add comment to reference

* feat: remove black linter

* feat: remove poetry to run the command

* feat: change logic to test mypy

* feat: update tests yml to try to fix the tests gh action

* feat: try to add just mypy to run on gh action

* feat: fix yml file

* feat: add comment to avoid issue on gh action

* feat: decouple pytest from the necessity of poetry install

* feat: change tests.yml to test different approach

* feat: change to poetry run

* fix: parameter field on yml file

* fix: update parameters to be on the pyproject

* fix: update pyproject to remove import untyped errors
This commit is contained in:
Eduardo Chiarotti
2024-05-10 16:37:52 -03:00
committed by GitHub
parent 8430c2f9af
commit 1ec4da6947
23 changed files with 1064 additions and 310 deletions

View File

@@ -1,3 +1,5 @@
from typing import Optional
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory
@@ -32,7 +34,7 @@ class ContextualMemory:
formatted_results = "\n".join([f"- {result}" for result in stm_results])
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> str:
def _fetch_ltm_context(self, task) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -44,10 +46,10 @@ class ContextualMemory:
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results])
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
@@ -58,6 +60,6 @@ class ContextualMemory:
"""
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results]
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
return f"Entities:\n{formatted_results}" if em_results else ""

View File

@@ -16,7 +16,7 @@ class EntityMemory(Memory):
)
super().__init__(storage)
def save(self, item: EntityMemoryItem) -> None:
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)

View File

@@ -18,10 +18,10 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None:
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
self.storage.save(
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
@@ -29,4 +29,4 @@ class LongTermMemory(Memory):
)
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
return self.storage.load(task, latest_n)
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
class LongTermMemoryItem:
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Union[int, float] = None,
metadata: Dict[str, Any] = None,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
self.task = task
self.agent = agent

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional
from crewai.memory.storage.interface import Storage
@@ -12,12 +12,16 @@ class Memory:
self.storage = storage
def save(
self, value: Any, metadata: Dict[str, Any] = None, agent: str = None
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
self.storage.save(value, metadata)
self.storage.save(value, metadata) # type: ignore # Maybe BUG? Should be self.storage.save(key, value, metadata)
def search(self, query: str) -> Dict[str, Any]:
return self.storage.search(query)

View File

@@ -16,8 +16,8 @@ class ShortTermMemory(Memory):
storage = RAGStorage(type="short_term", embedder_config=embedder_config)
super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None:
def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
super().save(item.data, item.metadata, item.agent)
def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold)
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters

View File

@@ -1,8 +1,10 @@
from typing import Any, Dict
from typing import Any, Dict, Optional
class ShortTermMemoryItem:
def __init__(self, data: Any, agent: str, metadata: Dict[str, Any] = None):
def __init__(
self, data: Any, agent: str, metadata: Optional[Dict[str, Any]] = None
):
self.data = data
self.agent = agent
self.metadata = metadata if metadata is not None else {}

View File

@@ -7,5 +7,5 @@ class Storage:
def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
pass
def search(self, key: str) -> Dict[str, Any]:
def search(self, key: str) -> Dict[str, Any]: # type: ignore
pass

View File

@@ -1,6 +1,6 @@
import json
import sqlite3
from typing import Any, Dict, Union
from typing import Any, Dict, List, Optional, Union
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
@@ -11,7 +11,9 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""
def __init__(self, db_path=f"{db_storage_path()}/long_term_memory_storage.db"):
def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
@@ -67,7 +69,9 @@ class LTMSQLiteStorage:
color="red",
)
def load(self, task_description: str, latest_n: int) -> Dict[str, Any]:
def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:

View File

@@ -2,7 +2,7 @@ import contextlib
import io
import logging
import os
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from embedchain import App
from embedchain.llm.base import BaseLlm
@@ -72,16 +72,16 @@ class RAGStorage(Storage):
if allow_reset:
self.app.reset()
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage"
self._generate_embedding(value, metadata)
def search(
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
self,
query: str,
limit: int = 3,
filter: dict = None,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> Dict[str, Any]:
) -> List[Any]:
with suppress_logging():
try:
results = (