From 74269697369e9ce7afc01a48ace8d0b1e2f795d6 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 19 Sep 2025 22:20:13 -0400 Subject: [PATCH] chore: apply ruff linting fixes and type annotations to memory module Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> --- src/crewai/memory/__init__.py | 4 ++-- .../memory/external/external_memory_item.py | 6 ++--- .../memory/long_term/long_term_memory.py | 12 +++++----- .../memory/long_term/long_term_memory_item.py | 6 ++--- .../short_term/short_term_memory_item.py | 6 ++--- src/crewai/memory/storage/interface.py | 6 ++--- .../storage/kickoff_task_outputs_storage.py | 22 +++++++++---------- .../memory/storage/ltm_sqlite_storage.py | 18 ++++++--------- src/crewai/memory/storage/rag_storage.py | 18 +++++++++++---- 9 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/crewai/memory/__init__.py b/src/crewai/memory/__init__.py index c80c41f70..0c8aacdde 100644 --- a/src/crewai/memory/__init__.py +++ b/src/crewai/memory/__init__.py @@ -1,11 +1,11 @@ from .entity.entity_memory import EntityMemory +from .external.external_memory import ExternalMemory from .long_term.long_term_memory import LongTermMemory from .short_term.short_term_memory import ShortTermMemory -from .external.external_memory import ExternalMemory __all__ = [ "EntityMemory", + "ExternalMemory", "LongTermMemory", "ShortTermMemory", - "ExternalMemory", ] diff --git a/src/crewai/memory/external/external_memory_item.py b/src/crewai/memory/external/external_memory_item.py index c97cccd59..f66b16c3d 100644 --- a/src/crewai/memory/external/external_memory_item.py +++ b/src/crewai/memory/external/external_memory_item.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, Optional +from typing import Any class ExternalMemoryItem: def __init__( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, + metadata: dict[str, Any] | None = None, + agent: str | None = None, ): self.value = value self.metadata = metadata diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 35460ef84..038d07e83 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -1,17 +1,17 @@ -from typing import Any, Dict, List import time +from typing import Any -from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem -from crewai.memory.memory import Memory from crewai.events.event_bus import crewai_event_bus from crewai.events.types.memory_events import ( - MemoryQueryStartedEvent, MemoryQueryCompletedEvent, MemoryQueryFailedEvent, - MemorySaveStartedEvent, + MemoryQueryStartedEvent, MemorySaveCompletedEvent, MemorySaveFailedEvent, + MemorySaveStartedEvent, ) +from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem +from crewai.memory.memory import Memory from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage @@ -84,7 +84,7 @@ class LongTermMemory(Memory): self, task: str, latest_n: int = 3, - ) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" + ) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( diff --git a/src/crewai/memory/long_term/long_term_memory_item.py b/src/crewai/memory/long_term/long_term_memory_item.py index b2164f242..5196b2548 100644 --- a/src/crewai/memory/long_term/long_term_memory_item.py +++ b/src/crewai/memory/long_term/long_term_memory_item.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any class LongTermMemoryItem: @@ -8,8 +8,8 @@ class LongTermMemoryItem: task: str, expected_output: str, datetime: str, - quality: Optional[Union[int, float]] = None, - metadata: Optional[Dict[str, Any]] = None, + quality: int | float | None = None, + metadata: dict[str, Any] | None = None, ): self.task = task self.agent = agent diff --git a/src/crewai/memory/short_term/short_term_memory_item.py b/src/crewai/memory/short_term/short_term_memory_item.py index 83b7f842f..d04a291e1 100644 --- a/src/crewai/memory/short_term/short_term_memory_item.py +++ b/src/crewai/memory/short_term/short_term_memory_item.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, Optional +from typing import Any class ShortTermMemoryItem: def __init__( self, data: Any, - agent: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + agent: str | None = None, + metadata: dict[str, Any] | None = None, ): self.data = data self.agent = agent diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index 8bec9a14f..90634bce7 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -1,15 +1,15 @@ -from typing import Any, Dict, List +from typing import Any class Storage: """Abstract base class defining the storage interface""" - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: pass def search( self, query: str, limit: int, score_threshold: float - ) -> Dict[str, Any] | List[Any]: + ) -> dict[str, Any] | list[Any]: return {} def reset(self) -> None: diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index c84c54f1c..c8643a153 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -2,7 +2,7 @@ import json import logging import sqlite3 from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from crewai.task import Task from crewai.utilities import Printer @@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage: An updated SQLite storage class for kickoff task outputs storage. """ - def __init__(self, db_path: Optional[str] = None) -> None: + def __init__(self, db_path: str | None = None) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db") @@ -57,15 +57,15 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e) logger.error(error_msg) - raise DatabaseOperationError(error_msg, e) + raise DatabaseOperationError(error_msg, e) from e def add( self, task: Task, - output: Dict[str, Any], + output: dict[str, Any], task_index: int, was_replayed: bool = False, - inputs: Dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, ) -> None: """Add a new task output record to the database. @@ -103,7 +103,7 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e) logger.error(error_msg) - raise DatabaseOperationError(error_msg, e) + raise DatabaseOperationError(error_msg, e) from e def update( self, @@ -138,7 +138,7 @@ class KickoffTaskOutputsSQLiteStorage: else value ) - query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec + query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608 values.append(task_index) cursor.execute(query, tuple(values)) @@ -151,9 +151,9 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e) logger.error(error_msg) - raise DatabaseOperationError(error_msg, e) + raise DatabaseOperationError(error_msg, e) from e - def load(self) -> List[Dict[str, Any]]: + def load(self) -> list[dict[str, Any]]: """Load all task output records from the database. Returns: @@ -192,7 +192,7 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e) logger.error(error_msg) - raise DatabaseOperationError(error_msg, e) + raise DatabaseOperationError(error_msg, e) from e def delete_all(self) -> None: """Delete all task output records from the database. @@ -212,4 +212,4 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e) logger.error(error_msg) - raise DatabaseOperationError(error_msg, e) + raise DatabaseOperationError(error_msg, e) from e diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 35f54e0e7..abf117c63 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -1,7 +1,7 @@ import json import sqlite3 from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any from crewai.utilities import Printer from crewai.utilities.paths import db_storage_path @@ -12,9 +12,7 @@ class LTMSQLiteStorage: An updated SQLite storage class for LTM data storage. """ - def __init__( - self, db_path: Optional[str] = None - ) -> None: + def __init__(self, db_path: str | None = None) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db") @@ -53,9 +51,9 @@ class LTMSQLiteStorage: def save( self, task_description: str, - metadata: Dict[str, Any], + metadata: dict[str, Any], datetime: str, - score: Union[int, float], + score: int | float, ) -> None: """Saves data to the LTM table with error handling.""" try: @@ -75,9 +73,7 @@ class LTMSQLiteStorage: color="red", ) - def load( - self, task_description: str, latest_n: int - ) -> Optional[List[Dict[str, Any]]]: + def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None: """Queries the LTM table by task description with error handling.""" try: with sqlite3.connect(self.db_path) as conn: @@ -89,7 +85,7 @@ class LTMSQLiteStorage: WHERE task_description = ? ORDER BY datetime DESC, score ASC LIMIT {latest_n} - """, # nosec + """, # nosec # noqa: S608 (task_description,), ) rows = cursor.fetchall() @@ -125,4 +121,4 @@ class LTMSQLiteStorage: content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", color="red", ) - return None + return diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 7e66a262c..f1ae919bc 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -1,9 +1,10 @@ import logging import traceback import warnings -from typing import Any +from typing import Any, cast from crewai.rag.chromadb.config import ChromaDBConfig +from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient from crewai.rag.embeddings.factory import get_embedding_function @@ -21,8 +22,13 @@ class RAGStorage(BaseRAGStorage): """ def __init__( - self, type, allow_reset=True, embedder_config=None, crew=None, path=None - ): + self, + type: str, + allow_reset: bool = True, + embedder_config: dict[str, Any] | None = None, + crew: Any = None, + path: str | None = None, + ) -> None: super().__init__(type, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] @@ -44,7 +50,11 @@ class RAGStorage(BaseRAGStorage): if self.embedder_config: embedding_function = get_embedding_function(self.embedder_config) - config = ChromaDBConfig(embedding_function=embedding_function) + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ) + ) self._client = create_client(config) def _get_client(self) -> BaseClient: