Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-11 00:58:30 +00:00)
refactor: update database connections to use storage_path
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -1,4 +1,5 @@
 from typing import Any, Dict, Optional
+from crewai.task import Task
 
 from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
 
@@ -21,7 +22,7 @@ class ContextualMemory:
         self.em = em
         self.um = um
 
-    def build_context_for_task(self, task, context) -> str:
+    def build_context_for_task(self, task: Task, context: str) -> str:
         """
         Automatically builds a minimal, highly relevant set of contextual information
         for a given task.
@@ -39,7 +40,7 @@ class ContextualMemory:
             context.append(self._fetch_user_context(query))
         return "\n".join(filter(None, context))
 
-    def _fetch_stm_context(self, query) -> str:
+    def _fetch_stm_context(self, query: str) -> str:
         """
         Fetches recent relevant insights from STM related to the task's description and expected_output,
         formatted as bullet points.
@@ -53,7 +54,7 @@ class ContextualMemory:
         )
         return f"Recent Insights:\n{formatted_results}" if stm_results else ""
 
-    def _fetch_ltm_context(self, task) -> Optional[str]:
+    def _fetch_ltm_context(self, task: str) -> Optional[str]:
         """
         Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
         formatted as bullet points.
@@ -72,7 +73,7 @@ class ContextualMemory:
 
         return f"Historical Data:\n{formatted_results}" if ltm_results else ""
 
-    def _fetch_entity_context(self, query) -> str:
+    def _fetch_entity_context(self, query: str) -> str:
         """
         Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
         formatted as bullet points.
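The hunks above only tighten ContextualMemory's type surface: build_context_for_task now takes a crewai Task plus a context string, and the private fetch helpers annotate their query/task arguments. A minimal call-site sketch, assuming contextual_memory is an already wired-up ContextualMemory instance (that variable and its construction are not part of this commit):

from crewai.task import Task  # import added by the first hunk

task = Task(
    description="Summarize this week's customer feedback",
    expected_output="A short bullet list of recurring themes",
)

# `contextual_memory` is assumed to already exist with short-term,
# long-term, entity and user memory attached.
context_block: str = contextual_memory.build_context_for_task(task, context="")

The hunks that follow turn to the abstract BaseRAGStorage base class.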
@@ -1,6 +1,9 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+import os
+from typing import Any, Dict, List, Optional, TypeVar
+from abc import ABC, abstractmethod
+from pathlib import Path
 
 from crewai.utilities.paths import db_storage_path
 
@@ -19,15 +22,42 @@ class BaseRAGStorage(ABC):
         allow_reset: bool = True,
         embedder_config: Optional[Any] = None,
         crew: Any = None,
-    ):
+    ) -> None:
+        """Initialize the BaseRAGStorage.
+
+        Args:
+            type: Type of storage being used
+            storage_path: Optional custom path for storage location
+            allow_reset: Whether storage can be reset
+            embedder_config: Optional configuration for the embedder
+            crew: Optional crew instance this storage belongs to
+
+        Raises:
+            PermissionError: If storage path is not writable
+            OSError: If storage path cannot be created
+        """
         self.type = type
         self.storage_path = storage_path if storage_path else db_storage_path()
+
+        # Validate storage path
+        try:
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+            if not os.access(self.storage_path.parent, os.W_OK):
+                raise PermissionError(f"No write permission for storage path: {self.storage_path}")
+        except OSError as e:
+            raise OSError(f"Failed to initialize storage path: {str(e)}")
+
         self.allow_reset = allow_reset
         self.embedder_config = embedder_config
         self.crew = crew
         self.agents = self._initialize_agents()
 
     def _initialize_agents(self) -> str:
+        """Initialize agent identifiers for storage.
+
+        Returns:
+            str: Underscore-joined string of sanitized agent role names
+        """
         if self.crew:
             return "_".join(
                 [self._sanitize_role(agent.role) for agent in self.crew.agents]
@@ -36,12 +66,27 @@ class BaseRAGStorage(ABC):
 
     @abstractmethod
     def _sanitize_role(self, role: str) -> str:
-        """Sanitizes agent roles to ensure valid directory names."""
+        """Sanitizes agent roles to ensure valid directory names.
+
+        Args:
+            role: The agent role name to sanitize
+
+        Returns:
+            str: Sanitized role name safe for use in paths
+        """
         pass
 
     @abstractmethod
     def save(self, value: Any, metadata: Dict[str, Any]) -> None:
-        """Save a value with metadata to the storage."""
+        """Save a value with metadata to the storage.
+
+        Args:
+            value: The value to store
+            metadata: Additional metadata to store with the value
+
+        Raises:
+            OSError: If there is an error writing to storage
+        """
         pass
 
     @abstractmethod
@@ -51,25 +96,55 @@ class BaseRAGStorage(ABC):
         limit: int = 3,
         filter: Optional[dict] = None,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
-        """Search for entries in the storage."""
+    ) -> List[Dict[str, Any]]:
+        """Search for entries in the storage.
+
+        Args:
+            query: The search query string
+            limit: Maximum number of results to return
+            filter: Optional filter criteria
+            score_threshold: Minimum similarity score threshold
+
+        Returns:
+            List[Dict[str, Any]]: List of matching entries with their metadata
+        """
         pass
 
     @abstractmethod
     def reset(self) -> None:
-        """Reset the storage."""
+        """Reset the storage.
+
+        Raises:
+            OSError: If there is an error clearing storage
+            PermissionError: If reset is not allowed
+        """
         pass
 
     @abstractmethod
     def _generate_embedding(
         self, text: str, metadata: Optional[Dict[str, Any]] = None
-    ) -> Any:
-        """Generate an embedding for the given text and metadata."""
+    ) -> List[float]:
+        """Generate an embedding for the given text and metadata.
+
+        Args:
+            text: Text to generate embedding for
+            metadata: Optional metadata to include in embedding
+
+        Returns:
+            List[float]: Vector embedding of the text
+
+        Raises:
+            ValueError: If text is empty or invalid
+        """
         pass
 
     @abstractmethod
-    def _initialize_app(self):
-        """Initialize the vector db."""
+    def _initialize_app(self) -> None:
+        """Initialize the vector db.
+
+        Raises:
+            OSError: If vector db initialization fails
+        """
         pass
 
     def setup_config(self, config: Dict[str, Any]):
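Beyond the Google-style docstrings, the substantive change to BaseRAGStorage is the constructor: it now declares a None return type and validates the resolved storage_path up front, creating the parent directory and failing fast when it is not writable. The same validation block is repeated verbatim in the two SQLite storages below. As a standalone illustration of that check (a sketch, not code from the commit; the helper name is made up):

import os
from pathlib import Path


def validate_storage_path(storage_path: Path) -> None:
    # Create the parent directory if needed, then fail fast when it is
    # not writable -- mirroring the block added in each __init__ above.
    try:
        storage_path.parent.mkdir(parents=True, exist_ok=True)
        if not os.access(storage_path.parent, os.W_OK):
            raise PermissionError(
                f"No write permission for storage path: {storage_path}"
            )
    except OSError as e:
        raise OSError(f"Failed to initialize storage path: {str(e)}")

Note that the check accesses storage_path.parent, so it assumes the default returned by db_storage_path() behaves like a pathlib.Path. The remaining hunks apply the same treatment to the two concrete SQLite stores, starting with KickoffTaskOutputsSQLiteStorage.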
@@ -1,5 +1,7 @@
 import json
+import os
 import sqlite3
+from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 from crewai.task import Task
@@ -13,12 +15,30 @@ class KickoffTaskOutputsSQLiteStorage:
     An updated SQLite storage class for kickoff task outputs storage.
     """
 
-    def __init__(self, db_path: Optional[str] = None) -> None:
-        self.db_path = (
-            db_path
-            if db_path
-            else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
+    def __init__(self, storage_path: Optional[Path] = None) -> None:
+        """Initialize kickoff task outputs storage.
+
+        Args:
+            storage_path: Optional custom path for storage location
+
+        Raises:
+            PermissionError: If storage path is not writable
+            OSError: If storage path cannot be created
+        """
+        self.storage_path = (
+            storage_path
+            if storage_path
+            else Path(f"{db_storage_path()}/latest_kickoff_task_outputs.db")
         )
+
+        # Validate storage path
+        try:
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+            if not os.access(self.storage_path.parent, os.W_OK):
+                raise PermissionError(f"No write permission for storage path: {self.storage_path}")
+        except OSError as e:
+            raise OSError(f"Failed to initialize storage path: {str(e)}")
+
         self._printer: Printer = Printer()
         self._initialize_db()
 
@@ -27,7 +47,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Initializes the SQLite database and creates LTM table
         """
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute(
                     """
@@ -57,9 +77,21 @@ class KickoffTaskOutputsSQLiteStorage:
         task_index: int,
         was_replayed: bool = False,
         inputs: Dict[str, Any] = {},
-    ):
+    ) -> None:
+        """Add a task output to storage.
+
+        Args:
+            task: The task whose output is being stored
+            output: The output data from the task
+            task_index: Index of this task in the sequence
+            was_replayed: Whether this was from a replay
+            inputs: Optional input data that led to this output
+
+        Raises:
+            sqlite3.Error: If there is an error saving to database
+        """
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute(
                     """
@@ -92,7 +124,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
         """
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
 
                 fields = []
@@ -121,7 +153,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
     def load(self) -> Optional[List[Dict[str, Any]]]:
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute("""
                     SELECT *
@@ -157,7 +189,7 @@ class KickoffTaskOutputsSQLiteStorage:
        Deletes all rows from the latest_kickoff_task_outputs table.
        """
        try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM latest_kickoff_task_outputs")
                conn.commit()
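For KickoffTaskOutputsSQLiteStorage the constructor parameter changes from a db_path string to a storage_path pathlib.Path, and every sqlite3.connect call is wrapped in str(...) accordingly. This is breaking for callers that passed a plain string, since the new validation block accesses storage_path.parent. A usage sketch, assuming the module path below is unchanged and using a hypothetical file location:

from pathlib import Path

from crewai.memory.storage.kickoff_task_outputs_storage import (
    KickoffTaskOutputsSQLiteStorage,
)

# Hypothetical writable location; the parameter is now a Path, not a str.
storage = KickoffTaskOutputsSQLiteStorage(
    storage_path=Path("/tmp/crewai/latest_kickoff_task_outputs.db")
)
outputs = storage.load()  # Optional[List[Dict[str, Any]]], per load() above

The final set of hunks covers LTMSQLiteStorage.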
@@ -1,5 +1,7 @@
 import json
+import os
 import sqlite3
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 from crewai.utilities import Printer
@@ -11,10 +13,26 @@ class LTMSQLiteStorage:
     An updated SQLite storage class for LTM data storage.
     """
 
-    def __init__(self, db_path: Optional[str] = None) -> None:
-        self.db_path = (
-            db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
-        )
+    def __init__(self, storage_path: Optional[Path] = None) -> None:
+        """Initialize LTM SQLite storage.
+
+        Args:
+            storage_path: Optional custom path for storage location
+
+        Raises:
+            PermissionError: If storage path is not writable
+            OSError: If storage path cannot be created
+        """
+        self.storage_path = storage_path if storage_path else Path(f"{db_storage_path()}/latest_long_term_memories.db")
+
+        # Validate storage path
+        try:
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+            if not os.access(self.storage_path.parent, os.W_OK):
+                raise PermissionError(f"No write permission for storage path: {self.storage_path}")
+        except OSError as e:
+            raise OSError(f"Failed to initialize storage path: {str(e)}")
+
         self._printer: Printer = Printer()
         self._initialize_db()
 
@@ -23,7 +41,7 @@ class LTMSQLiteStorage:
         Initializes the SQLite database and creates LTM table
         """
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute(
                     """
@@ -51,9 +69,20 @@ class LTMSQLiteStorage:
         datetime: str,
         score: Union[int, float],
     ) -> None:
+        """Save a memory entry to long-term memory.
+
+        Args:
+            task_description: Description of the task this memory relates to
+            metadata: Additional data to store with the memory
+            datetime: Timestamp for when this memory was created
+            score: Relevance score for this memory (higher is more relevant)
+
+        Raises:
+            sqlite3.Error: If there is an error saving to the database
+        """
         """Saves data to the LTM table with error handling."""
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute(
                     """
@@ -74,7 +103,7 @@ class LTMSQLiteStorage:
     ) -> Optional[List[Dict[str, Any]]]:
         """Queries the LTM table by task description with error handling."""
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute(
                     f"""
@@ -109,7 +138,7 @@ class LTMSQLiteStorage:
     ) -> None:
         """Resets the LTM table with error handling."""
         try:
-            with sqlite3.connect(self.db_path) as conn:
+            with sqlite3.connect(str(self.storage_path)) as conn:
                 cursor = conn.cursor()
                 cursor.execute("DELETE FROM long_term_memories")
                 conn.commit()
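LTMSQLiteStorage gets the same treatment: a storage_path pathlib.Path parameter with upfront validation, str(self.storage_path) in every sqlite3.connect call, and a new Google-style docstring for save (the original one-line docstring remains directly beneath it). A usage sketch under the same module-path assumption, with a hypothetical file location:

from pathlib import Path

from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage

# Default: the database lands under db_storage_path().
default_ltm = LTMSQLiteStorage()

# Explicit, hypothetical location, now passed as a pathlib.Path.
custom_ltm = LTMSQLiteStorage(
    storage_path=Path("/tmp/crewai/latest_long_term_memories.db")
)
custom_ltm.save(
    task_description="Research competitor pricing",
    metadata={"quality": 8},
    datetime="2026-01-11T00:58:30",
    score=8,
)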