From 58b2ba4d900123a24296962f0c1ec05d2920a2e9 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 28 Dec 2024 00:56:10 +0000 Subject: [PATCH] refactor: update database connections to use storage_path Co-Authored-By: Joe Moura --- .../memory/contextual/contextual_memory.py | 9 +- src/crewai/memory/storage/base_rag_storage.py | 97 ++++++++++++++++--- .../storage/kickoff_task_outputs_storage.py | 54 ++++++++--- .../memory/storage/ltm_sqlite_storage.py | 45 +++++++-- 4 files changed, 171 insertions(+), 34 deletions(-) diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index b7baaa92c..fc1e8ca90 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional +from crewai.task import Task from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory @@ -21,7 +22,7 @@ class ContextualMemory: self.em = em self.um = um - def build_context_for_task(self, task, context) -> str: + def build_context_for_task(self, task: Task, context: str) -> str: """ Automatically builds a minimal, highly relevant set of contextual information for a given task. @@ -39,7 +40,7 @@ class ContextualMemory: context.append(self._fetch_user_context(query)) return "\n".join(filter(None, context)) - def _fetch_stm_context(self, query) -> str: + def _fetch_stm_context(self, query: str) -> str: """ Fetches recent relevant insights from STM related to the task's description and expected_output, formatted as bullet points. @@ -53,7 +54,7 @@ class ContextualMemory: ) return f"Recent Insights:\n{formatted_results}" if stm_results else "" - def _fetch_ltm_context(self, task) -> Optional[str]: + def _fetch_ltm_context(self, task: str) -> Optional[str]: """ Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, formatted as bullet points. @@ -72,7 +73,7 @@ class ContextualMemory: return f"Historical Data:\n{formatted_results}" if ltm_results else "" - def _fetch_entity_context(self, query) -> str: + def _fetch_entity_context(self, query: str) -> str: """ Fetches relevant entity information from Entity Memory related to the task's description and expected_output, formatted as bullet points. diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 9c26bf293..9ff827484 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional +import os +from typing import Any, Dict, List, Optional, TypeVar +from abc import ABC, abstractmethod +from pathlib import Path from crewai.utilities.paths import db_storage_path @@ -19,15 +22,42 @@ class BaseRAGStorage(ABC): allow_reset: bool = True, embedder_config: Optional[Any] = None, crew: Any = None, - ): + ) -> None: + """Initialize the BaseRAGStorage. + + Args: + type: Type of storage being used + storage_path: Optional custom path for storage location + allow_reset: Whether storage can be reset + embedder_config: Optional configuration for the embedder + crew: Optional crew instance this storage belongs to + + Raises: + PermissionError: If storage path is not writable + OSError: If storage path cannot be created + """ self.type = type self.storage_path = storage_path if storage_path else db_storage_path() + + # Validate storage path + try: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + if not os.access(self.storage_path.parent, os.W_OK): + raise PermissionError(f"No write permission for storage path: {self.storage_path}") + except OSError as e: + raise OSError(f"Failed to initialize storage path: {str(e)}") + self.allow_reset = allow_reset self.embedder_config = embedder_config self.crew = crew self.agents = self._initialize_agents() def _initialize_agents(self) -> str: + """Initialize agent identifiers for storage. + + Returns: + str: Underscore-joined string of sanitized agent role names + """ if self.crew: return "_".join( [self._sanitize_role(agent.role) for agent in self.crew.agents] @@ -36,12 +66,27 @@ class BaseRAGStorage(ABC): @abstractmethod def _sanitize_role(self, role: str) -> str: - """Sanitizes agent roles to ensure valid directory names.""" + """Sanitizes agent roles to ensure valid directory names. + + Args: + role: The agent role name to sanitize + + Returns: + str: Sanitized role name safe for use in paths + """ pass @abstractmethod def save(self, value: Any, metadata: Dict[str, Any]) -> None: - """Save a value with metadata to the storage.""" + """Save a value with metadata to the storage. + + Args: + value: The value to store + metadata: Additional metadata to store with the value + + Raises: + OSError: If there is an error writing to storage + """ pass @abstractmethod @@ -51,25 +96,55 @@ class BaseRAGStorage(ABC): limit: int = 3, filter: Optional[dict] = None, score_threshold: float = 0.35, - ) -> List[Any]: - """Search for entries in the storage.""" + ) -> List[Dict[str, Any]]: + """Search for entries in the storage. + + Args: + query: The search query string + limit: Maximum number of results to return + filter: Optional filter criteria + score_threshold: Minimum similarity score threshold + + Returns: + List[Dict[str, Any]]: List of matching entries with their metadata + """ pass @abstractmethod def reset(self) -> None: - """Reset the storage.""" + """Reset the storage. + + Raises: + OSError: If there is an error clearing storage + PermissionError: If reset is not allowed + """ pass @abstractmethod def _generate_embedding( self, text: str, metadata: Optional[Dict[str, Any]] = None - ) -> Any: - """Generate an embedding for the given text and metadata.""" + ) -> List[float]: + """Generate an embedding for the given text and metadata. + + Args: + text: Text to generate embedding for + metadata: Optional metadata to include in embedding + + Returns: + List[float]: Vector embedding of the text + + Raises: + ValueError: If text is empty or invalid + """ pass @abstractmethod - def _initialize_app(self): - """Initialize the vector db.""" + def _initialize_app(self) -> None: + """Initialize the vector db. + + Raises: + OSError: If vector db initialization fails + """ pass def setup_config(self, config: Dict[str, Any]): diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index 00e949d39..284e9639c 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -1,5 +1,7 @@ import json +import os import sqlite3 +from pathlib import Path from typing import Any, Dict, List, Optional from crewai.task import Task @@ -13,12 +15,30 @@ class KickoffTaskOutputsSQLiteStorage: An updated SQLite storage class for kickoff task outputs storage. """ - def __init__(self, db_path: Optional[str] = None) -> None: - self.db_path = ( - db_path - if db_path - else f"{db_storage_path()}/latest_kickoff_task_outputs.db" + def __init__(self, storage_path: Optional[Path] = None) -> None: + """Initialize kickoff task outputs storage. + + Args: + storage_path: Optional custom path for storage location + + Raises: + PermissionError: If storage path is not writable + OSError: If storage path cannot be created + """ + self.storage_path = ( + storage_path + if storage_path + else Path(f"{db_storage_path()}/latest_kickoff_task_outputs.db") ) + + # Validate storage path + try: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + if not os.access(self.storage_path.parent, os.W_OK): + raise PermissionError(f"No write permission for storage path: {self.storage_path}") + except OSError as e: + raise OSError(f"Failed to initialize storage path: {str(e)}") + self._printer: Printer = Printer() self._initialize_db() @@ -27,7 +47,7 @@ class KickoffTaskOutputsSQLiteStorage: Initializes the SQLite database and creates LTM table """ try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute( """ @@ -57,9 +77,21 @@ class KickoffTaskOutputsSQLiteStorage: task_index: int, was_replayed: bool = False, inputs: Dict[str, Any] = {}, - ): + ) -> None: + """Add a task output to storage. + + Args: + task: The task whose output is being stored + output: The output data from the task + task_index: Index of this task in the sequence + was_replayed: Whether this was from a replay + inputs: Optional input data that led to this output + + Raises: + sqlite3.Error: If there is an error saving to database + """ try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute( """ @@ -92,7 +124,7 @@ class KickoffTaskOutputsSQLiteStorage: Updates an existing row in the latest_kickoff_task_outputs table based on task_index. """ try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() fields = [] @@ -121,7 +153,7 @@ class KickoffTaskOutputsSQLiteStorage: def load(self) -> Optional[List[Dict[str, Any]]]: try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute(""" SELECT * @@ -157,7 +189,7 @@ class KickoffTaskOutputsSQLiteStorage: Deletes all rows from the latest_kickoff_task_outputs table. """ try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute("DELETE FROM latest_kickoff_task_outputs") conn.commit() diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 97ccfa14b..8a61cdfc1 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -1,5 +1,7 @@ import json +import os import sqlite3 +from pathlib import Path from typing import Any, Dict, List, Optional, Union from crewai.utilities import Printer @@ -11,10 +13,26 @@ class LTMSQLiteStorage: An updated SQLite storage class for LTM data storage. """ - def __init__(self, db_path: Optional[str] = None) -> None: - self.db_path = ( - db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db" - ) + def __init__(self, storage_path: Optional[Path] = None) -> None: + """Initialize LTM SQLite storage. + + Args: + storage_path: Optional custom path for storage location + + Raises: + PermissionError: If storage path is not writable + OSError: If storage path cannot be created + """ + self.storage_path = storage_path if storage_path else Path(f"{db_storage_path()}/latest_long_term_memories.db") + + # Validate storage path + try: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + if not os.access(self.storage_path.parent, os.W_OK): + raise PermissionError(f"No write permission for storage path: {self.storage_path}") + except OSError as e: + raise OSError(f"Failed to initialize storage path: {str(e)}") + self._printer: Printer = Printer() self._initialize_db() @@ -23,7 +41,7 @@ class LTMSQLiteStorage: Initializes the SQLite database and creates LTM table """ try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute( """ @@ -51,9 +69,20 @@ class LTMSQLiteStorage: datetime: str, score: Union[int, float], ) -> None: + """Save a memory entry to long-term memory. + + Args: + task_description: Description of the task this memory relates to + metadata: Additional data to store with the memory + datetime: Timestamp for when this memory was created + score: Relevance score for this memory (higher is more relevant) + + Raises: + sqlite3.Error: If there is an error saving to the database + """ """Saves data to the LTM table with error handling.""" try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute( """ @@ -74,7 +103,7 @@ class LTMSQLiteStorage: ) -> Optional[List[Dict[str, Any]]]: """Queries the LTM table by task description with error handling.""" try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute( f""" @@ -109,7 +138,7 @@ class LTMSQLiteStorage: ) -> None: """Resets the LTM table with error handling.""" try: - with sqlite3.connect(self.db_path) as conn: + with sqlite3.connect(str(self.storage_path)) as conn: cursor = conn.cursor() cursor.execute("DELETE FROM long_term_memories") conn.commit()