mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
chore: apply ruff linting fixes and type annotations to memory module
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
@@ -1,11 +1,11 @@
|
|||||||
from .entity.entity_memory import EntityMemory
|
from .entity.entity_memory import EntityMemory
|
||||||
|
from .external.external_memory import ExternalMemory
|
||||||
from .long_term.long_term_memory import LongTermMemory
|
from .long_term.long_term_memory import LongTermMemory
|
||||||
from .short_term.short_term_memory import ShortTermMemory
|
from .short_term.short_term_memory import ShortTermMemory
|
||||||
from .external.external_memory import ExternalMemory
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EntityMemory",
|
"EntityMemory",
|
||||||
|
"ExternalMemory",
|
||||||
"LongTermMemory",
|
"LongTermMemory",
|
||||||
"ShortTermMemory",
|
"ShortTermMemory",
|
||||||
"ExternalMemory",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ExternalMemoryItem:
|
class ExternalMemoryItem:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
agent: Optional[str] = None,
|
agent: str | None = None,
|
||||||
):
|
):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from typing import Any, Dict, List
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
|
||||||
from crewai.memory.memory import Memory
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryQueryStartedEvent,
|
|
||||||
MemoryQueryCompletedEvent,
|
MemoryQueryCompletedEvent,
|
||||||
MemoryQueryFailedEvent,
|
MemoryQueryFailedEvent,
|
||||||
MemorySaveStartedEvent,
|
MemoryQueryStartedEvent,
|
||||||
MemorySaveCompletedEvent,
|
MemorySaveCompletedEvent,
|
||||||
MemorySaveFailedEvent,
|
MemorySaveFailedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
|
|||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
latest_n: int = 3,
|
latest_n: int = 3,
|
||||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class LongTermMemoryItem:
|
class LongTermMemoryItem:
|
||||||
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
|
|||||||
task: str,
|
task: str,
|
||||||
expected_output: str,
|
expected_output: str,
|
||||||
datetime: str,
|
datetime: str,
|
||||||
quality: Optional[Union[int, float]] = None,
|
quality: int | float | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.task = task
|
self.task = task
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ShortTermMemoryItem:
|
class ShortTermMemoryItem:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data: Any,
|
data: Any,
|
||||||
agent: Optional[str] = None,
|
agent: str | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
"""Abstract base class defining the storage interface"""
|
"""Abstract base class defining the storage interface"""
|
||||||
|
|
||||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, query: str, limit: int, score_threshold: float
|
self, query: str, limit: int, score_threshold: float
|
||||||
) -> Dict[str, Any] | List[Any]:
|
) -> dict[str, Any] | list[Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
An updated SQLite storage class for kickoff task outputs storage.
|
An updated SQLite storage class for kickoff task outputs storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
# Get the parent directory of the default db path and create our db file there
|
||||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||||
@@ -57,15 +57,15 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e) from e
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
output: Dict[str, Any],
|
output: dict[str, Any],
|
||||||
task_index: int,
|
task_index: int,
|
||||||
was_replayed: bool = False,
|
was_replayed: bool = False,
|
||||||
inputs: Dict[str, Any] | None = None,
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a new task output record to the database.
|
"""Add a new task output record to the database.
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e) from e
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -138,7 +138,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
else value
|
else value
|
||||||
)
|
)
|
||||||
|
|
||||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
|
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
|
||||||
values.append(task_index)
|
values.append(task_index)
|
||||||
|
|
||||||
cursor.execute(query, tuple(values))
|
cursor.execute(query, tuple(values))
|
||||||
@@ -151,9 +151,9 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e) from e
|
||||||
|
|
||||||
def load(self) -> List[Dict[str, Any]]:
|
def load(self) -> list[dict[str, Any]]:
|
||||||
"""Load all task output records from the database.
|
"""Load all task output records from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -192,7 +192,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
|
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e) from e
|
||||||
|
|
||||||
def delete_all(self) -> None:
|
def delete_all(self) -> None:
|
||||||
"""Delete all task output records from the database.
|
"""Delete all task output records from the database.
|
||||||
@@ -212,4 +212,4 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise DatabaseOperationError(error_msg, e)
|
raise DatabaseOperationError(error_msg, e) from e
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
|
|||||||
An updated SQLite storage class for LTM data storage.
|
An updated SQLite storage class for LTM data storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
self, db_path: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
# Get the parent directory of the default db path and create our db file there
|
||||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||||
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
metadata: Dict[str, Any],
|
metadata: dict[str, Any],
|
||||||
datetime: str,
|
datetime: str,
|
||||||
score: Union[int, float],
|
score: int | float,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves data to the LTM table with error handling."""
|
"""Saves data to the LTM table with error handling."""
|
||||||
try:
|
try:
|
||||||
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(
|
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
|
||||||
self, task_description: str, latest_n: int
|
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
|
||||||
"""Queries the LTM table by task description with error handling."""
|
"""Queries the LTM table by task description with error handling."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
@@ -89,7 +85,7 @@ class LTMSQLiteStorage:
|
|||||||
WHERE task_description = ?
|
WHERE task_description = ?
|
||||||
ORDER BY datetime DESC, score ASC
|
ORDER BY datetime DESC, score ASC
|
||||||
LIMIT {latest_n}
|
LIMIT {latest_n}
|
||||||
""", # nosec
|
""", # nosec # noqa: S608
|
||||||
(task_description,),
|
(task_description,),
|
||||||
)
|
)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
|
|||||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
return None
|
return
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||||
from crewai.rag.config.utils import get_rag_client
|
from crewai.rag.config.utils import get_rag_client
|
||||||
from crewai.rag.core.base_client import BaseClient
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.embeddings.factory import get_embedding_function
|
from crewai.rag.embeddings.factory import get_embedding_function
|
||||||
@@ -21,8 +22,13 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
self,
|
||||||
):
|
type: str,
|
||||||
|
allow_reset: bool = True,
|
||||||
|
embedder_config: dict[str, Any] | None = None,
|
||||||
|
crew: Any = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
super().__init__(type, allow_reset, embedder_config, crew)
|
super().__init__(type, allow_reset, embedder_config, crew)
|
||||||
agents = crew.agents if crew else []
|
agents = crew.agents if crew else []
|
||||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||||
@@ -44,7 +50,11 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
if self.embedder_config:
|
if self.embedder_config:
|
||||||
embedding_function = get_embedding_function(self.embedder_config)
|
embedding_function = get_embedding_function(self.embedder_config)
|
||||||
config = ChromaDBConfig(embedding_function=embedding_function)
|
config = ChromaDBConfig(
|
||||||
|
embedding_function=cast(
|
||||||
|
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||||
|
)
|
||||||
|
)
|
||||||
self._client = create_client(config)
|
self._client = create_client(config)
|
||||||
|
|
||||||
def _get_client(self) -> BaseClient:
|
def _get_client(self) -> BaseClient:
|
||||||
|
|||||||
Reference in New Issue
Block a user