chore: apply ruff linting fixes and type annotations to memory module
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
Greyson LaLonde
2025-09-19 22:20:13 -04:00
committed by GitHub
parent d879be8b66
commit 7426969736
9 changed files with 52 additions and 46 deletions

View File

@@ -1,11 +1,11 @@
from .entity.entity_memory import EntityMemory
from .external.external_memory import ExternalMemory
from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory
from .external.external_memory import ExternalMemory
__all__ = [
"EntityMemory",
"ExternalMemory",
"LongTermMemory",
"ShortTermMemory",
"ExternalMemory",
]

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any
class ExternalMemoryItem:
def __init__(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
):
self.value = value
self.metadata = metadata

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
class LongTermMemoryItem:
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
quality: int | float | None = None,
metadata: dict[str, Any] | None = None,
):
self.task = task
self.agent = agent

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any
class ShortTermMemoryItem:
def __init__(
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
agent: str | None = None,
metadata: dict[str, Any] | None = None,
):
self.data = data
self.agent = agent

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List
from typing import Any
class Storage:
"""Abstract base class defining the storage interface"""
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass
def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
) -> dict[str, Any] | list[Any]:
return {}
def reset(self) -> None:

View File

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.task import Task
from crewai.utilities import Printer
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""
def __init__(self, db_path: Optional[str] = None) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
@@ -57,15 +57,15 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
raise DatabaseOperationError(error_msg, e) from e
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
) -> None:
"""Add a new task output record to the database.
@@ -103,7 +103,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
raise DatabaseOperationError(error_msg, e) from e
def update(
self,
@@ -138,7 +138,7 @@ class KickoffTaskOutputsSQLiteStorage:
else value
)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index)
cursor.execute(query, tuple(values))
@@ -151,9 +151,9 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
raise DatabaseOperationError(error_msg, e) from e
def load(self) -> List[Dict[str, Any]]:
def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database.
Returns:
@@ -192,7 +192,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
raise DatabaseOperationError(error_msg, e) from e
def delete_all(self) -> None:
"""Delete all task output records from the database.
@@ -212,4 +212,4 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
raise DatabaseOperationError(error_msg, e) from e

View File

@@ -1,7 +1,7 @@
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""
def __init__(
self, db_path: Optional[str] = None
) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
def save(
self,
task_description: str,
metadata: Dict[str, Any],
metadata: dict[str, Any],
datetime: str,
score: Union[int, float],
score: int | float,
) -> None:
"""Saves data to the LTM table with error handling."""
try:
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
color="red",
)
def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -89,7 +85,7 @@ class LTMSQLiteStorage:
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec
""", # nosec # noqa: S608
(task_description,),
)
rows = cursor.fetchall()
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None
return

View File

@@ -1,9 +1,10 @@
import logging
import traceback
import warnings
from typing import Any
from typing import Any, cast
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
@@ -21,8 +22,13 @@ class RAGStorage(BaseRAGStorage):
"""
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
self,
type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -44,7 +50,11 @@ class RAGStorage(BaseRAGStorage):
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient: