chore: apply ruff linting fixes and type annotations to memory module
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
Greyson LaLonde
2025-09-19 22:20:13 -04:00
committed by GitHub
parent d879be8b66
commit 7426969736
9 changed files with 52 additions and 46 deletions

View File

@@ -1,11 +1,11 @@
from .entity.entity_memory import EntityMemory from .entity.entity_memory import EntityMemory
from .external.external_memory import ExternalMemory
from .long_term.long_term_memory import LongTermMemory from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory from .short_term.short_term_memory import ShortTermMemory
from .external.external_memory import ExternalMemory
__all__ = [ __all__ = [
"EntityMemory", "EntityMemory",
"ExternalMemory",
"LongTermMemory", "LongTermMemory",
"ShortTermMemory", "ShortTermMemory",
"ExternalMemory",
] ]

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional from typing import Any
class ExternalMemoryItem: class ExternalMemoryItem:
def __init__( def __init__(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
agent: Optional[str] = None, agent: str | None = None,
): ):
self.value = value self.value = value
self.metadata = metadata self.metadata = metadata

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time import time
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent, MemoryQueryCompletedEvent,
MemoryQueryFailedEvent, MemoryQueryFailedEvent,
MemorySaveStartedEvent, MemoryQueryStartedEvent,
MemorySaveCompletedEvent, MemorySaveCompletedEvent,
MemorySaveFailedEvent, MemorySaveFailedEvent,
MemorySaveStartedEvent,
) )
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self, self,
task: str, task: str,
latest_n: int = 3, latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" ) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryQueryStartedEvent( event=MemoryQueryStartedEvent(

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union from typing import Any
class LongTermMemoryItem: class LongTermMemoryItem:
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
task: str, task: str,
expected_output: str, expected_output: str,
datetime: str, datetime: str,
quality: Optional[Union[int, float]] = None, quality: int | float | None = None,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
): ):
self.task = task self.task = task
self.agent = agent self.agent = agent

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional from typing import Any
class ShortTermMemoryItem: class ShortTermMemoryItem:
def __init__( def __init__(
self, self,
data: Any, data: Any,
agent: Optional[str] = None, agent: str | None = None,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
): ):
self.data = data self.data = data
self.agent = agent self.agent = agent

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List from typing import Any
class Storage: class Storage:
"""Abstract base class defining the storage interface""" """Abstract base class defining the storage interface"""
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass pass
def search( def search(
self, query: str, limit: int, score_threshold: float self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]: ) -> dict[str, Any] | list[Any]:
return {} return {}
def reset(self) -> None: def reset(self) -> None:

View File

@@ -2,7 +2,7 @@ import json
import logging import logging
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any
from crewai.task import Task from crewai.task import Task
from crewai.utilities import Printer from crewai.utilities import Printer
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage. An updated SQLite storage class for kickoff task outputs storage.
""" """
def __init__(self, db_path: Optional[str] = None) -> None: def __init__(self, db_path: str | None = None) -> None:
if db_path is None: if db_path is None:
# Get the parent directory of the default db path and create our db file there # Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db") db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
@@ -57,15 +57,15 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e: except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e) error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg) logger.error(error_msg)
raise DatabaseOperationError(error_msg, e) raise DatabaseOperationError(error_msg, e) from e
def add( def add(
self, self,
task: Task, task: Task,
output: Dict[str, Any], output: dict[str, Any],
task_index: int, task_index: int,
was_replayed: bool = False, was_replayed: bool = False,
inputs: Dict[str, Any] | None = None, inputs: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Add a new task output record to the database. """Add a new task output record to the database.
@@ -103,7 +103,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e: except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e) error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg) logger.error(error_msg)
raise DatabaseOperationError(error_msg, e) raise DatabaseOperationError(error_msg, e) from e
def update( def update(
self, self,
@@ -138,7 +138,7 @@ class KickoffTaskOutputsSQLiteStorage:
else value else value
) )
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
values.append(task_index) values.append(task_index)
cursor.execute(query, tuple(values)) cursor.execute(query, tuple(values))
@@ -151,9 +151,9 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e: except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e) error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg) logger.error(error_msg)
raise DatabaseOperationError(error_msg, e) raise DatabaseOperationError(error_msg, e) from e
def load(self) -> List[Dict[str, Any]]: def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database. """Load all task output records from the database.
Returns: Returns:
@@ -192,7 +192,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e: except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e) error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
logger.error(error_msg) logger.error(error_msg)
raise DatabaseOperationError(error_msg, e) raise DatabaseOperationError(error_msg, e) from e
def delete_all(self) -> None: def delete_all(self) -> None:
"""Delete all task output records from the database. """Delete all task output records from the database.
@@ -212,4 +212,4 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e: except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e) error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg) logger.error(error_msg)
raise DatabaseOperationError(error_msg, e) raise DatabaseOperationError(error_msg, e) from e

View File

@@ -1,7 +1,7 @@
import json import json
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any
from crewai.utilities import Printer from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage. An updated SQLite storage class for LTM data storage.
""" """
def __init__( def __init__(self, db_path: str | None = None) -> None:
self, db_path: Optional[str] = None
) -> None:
if db_path is None: if db_path is None:
# Get the parent directory of the default db path and create our db file there # Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db") db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
def save( def save(
self, self,
task_description: str, task_description: str,
metadata: Dict[str, Any], metadata: dict[str, Any],
datetime: str, datetime: str,
score: Union[int, float], score: int | float,
) -> None: ) -> None:
"""Saves data to the LTM table with error handling.""" """Saves data to the LTM table with error handling."""
try: try:
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
color="red", color="red",
) )
def load( def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling.""" """Queries the LTM table by task description with error handling."""
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -89,7 +85,7 @@ class LTMSQLiteStorage:
WHERE task_description = ? WHERE task_description = ?
ORDER BY datetime DESC, score ASC ORDER BY datetime DESC, score ASC
LIMIT {latest_n} LIMIT {latest_n}
""", # nosec """, # nosec # noqa: S608
(task_description,), (task_description,),
) )
rows = cursor.fetchall() rows = cursor.fetchall()
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red", color="red",
) )
return None return

View File

@@ -1,9 +1,10 @@
import logging import logging
import traceback import traceback
import warnings import warnings
from typing import Any from typing import Any, cast
from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function from crewai.rag.embeddings.factory import get_embedding_function
@@ -21,8 +22,13 @@ class RAGStorage(BaseRAGStorage):
""" """
def __init__( def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None self,
): type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew) super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else [] agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -44,7 +50,11 @@ class RAGStorage(BaseRAGStorage):
if self.embedder_config: if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config) embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(embedding_function=embedding_function) config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config) self._client = create_client(config)
def _get_client(self) -> BaseClient: def _get_client(self) -> BaseClient: