mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-22 14:48:13 +00:00
Apply automatic linting fixes to src directory
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -12,13 +12,13 @@ from crewai.memory import (
|
||||
class ContextualMemory:
|
||||
def __init__(
|
||||
self,
|
||||
memory_config: Optional[Dict[str, Any]],
|
||||
memory_config: dict[str, Any] | None,
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
um: UserMemory,
|
||||
exm: ExternalMemory,
|
||||
):
|
||||
) -> None:
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
else:
|
||||
@@ -30,8 +30,7 @@ class ContextualMemory:
|
||||
self.exm = exm
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
"""Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
@@ -49,11 +48,9 @@ class ContextualMemory:
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
"""Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
@@ -62,16 +59,14 @@ class ContextualMemory:
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in stm_results
|
||||
]
|
||||
],
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
def _fetch_ltm_context(self, task) -> str | None:
|
||||
"""Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
@@ -90,8 +85,7 @@ class ContextualMemory:
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
"""Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
if self.em is None:
|
||||
@@ -102,19 +96,20 @@ class ContextualMemory:
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in em_results
|
||||
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
def _fetch_user_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant user information from User Memory.
|
||||
"""Fetches and formats relevant user information from User Memory.
|
||||
|
||||
Args:
|
||||
query (str): The search query to find relevant user memories.
|
||||
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
|
||||
"""
|
||||
if self.um is None:
|
||||
return ""
|
||||
|
||||
@@ -128,12 +123,14 @@ class ContextualMemory:
|
||||
return f"User memories/preferences:\n{formatted_memories}"
|
||||
|
||||
def _fetch_external_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant information from External Memory.
|
||||
"""Fetches and formats relevant information from External Memory.
|
||||
|
||||
Args:
|
||||
query (str): The search query to find relevant information.
|
||||
|
||||
Returns:
|
||||
str: Formatted information as bullet points, or an empty string if none found.
|
||||
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
"""
|
||||
EntityMemory class for managing structured information about entities
|
||||
"""EntityMemory class for managing structured information about entities
|
||||
and their relationships using SQLite storage.
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
else:
|
||||
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
||||
msg = f"An error occurred while resetting the entity memory: {e}"
|
||||
raise Exception(msg)
|
||||
|
||||
@@ -5,7 +5,7 @@ class EntityMemoryItem:
|
||||
type: str,
|
||||
description: str,
|
||||
relationships: str,
|
||||
):
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
|
||||
23
src/crewai/memory/external/external_memory.py
vendored
23
src/crewai/memory/external/external_memory.py
vendored
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ExternalMemory(Memory):
|
||||
def __init__(self, storage: Optional[Storage] = None, **data: Any):
|
||||
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@staticmethod
|
||||
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> Dict[str, Any]:
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
return {
|
||||
"mem0": ExternalMemory._configure_mem0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
msg = "embedder_config is required"
|
||||
raise ValueError(msg)
|
||||
|
||||
if "provider" not in embedder_config:
|
||||
raise ValueError("embedder_config must include a 'provider' key")
|
||||
msg = "embedder_config must include a 'provider' key"
|
||||
raise ValueError(msg)
|
||||
|
||||
provider = embedder_config["provider"]
|
||||
supported_storages = ExternalMemory.external_supported_storages()
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
msg = f"Provider {provider} not supported"
|
||||
raise ValueError(msg)
|
||||
|
||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExternalMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
):
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
self.value = value
|
||||
self.metadata = metadata
|
||||
self.agent = agent
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
"""
|
||||
LongTermMemory class for managing cross runs data related to overall crew's
|
||||
"""LongTermMemory class for managing cross runs data related to overall crew's
|
||||
execution and performance.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(self, storage=None, path=None) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LongTermMemoryItem:
|
||||
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
|
||||
task: str,
|
||||
expected_output: str,
|
||||
datetime: str,
|
||||
quality: Optional[Union[int, float]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
quality: float | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.quality = quality
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
"""Base class for memory, now supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
crew: Optional[Any] = None
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
def __init__(self, storage: Any, **data: Any) -> None:
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
@@ -33,9 +31,9 @@ class Memory(BaseModel):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
def set_crew(self, crew: Any) -> "Memory":
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
"""
|
||||
ShortTermMemory class for managing transient data related to immediate tasks
|
||||
"""ShortTermMemory class for managing transient data related to immediate tasks
|
||||
and interactions.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
else:
|
||||
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self._memory_provider == "mem0":
|
||||
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold,
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
msg = f"An error occurred while resetting the short-term memory: {e}"
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the short-term memory: {e}"
|
||||
msg,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ShortTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
data: Any,
|
||||
agent: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
agent: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
Base class for RAG-based Storage implementations.
|
||||
"""
|
||||
"""Base class for RAG-based Storage implementations."""
|
||||
|
||||
app: Any | None = None
|
||||
|
||||
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
) -> None:
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
|
||||
def _initialize_agents(self) -> str:
|
||||
if self.crew:
|
||||
return "_".join(
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents]
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents],
|
||||
)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: Optional[Dict[str, Any]] = None
|
||||
self, text: str, metadata: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Generate an embedding for the given text and metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_app(self):
|
||||
"""Initialize the vector db."""
|
||||
pass
|
||||
|
||||
def setup_config(self, config: Dict[str, Any]):
|
||||
def setup_config(self, config: dict[str, Any]) -> None:
|
||||
"""Setup the config of the storage."""
|
||||
pass
|
||||
|
||||
def initialize_client(self):
|
||||
"""Initialize the client of the storage. This should setup the app and the db collection"""
|
||||
pass
|
||||
def initialize_client(self) -> None:
|
||||
"""Initialize the client of the storage. This should setup the app and the db collection."""
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Abstract base class defining the storage interface"""
|
||||
"""Abstract base class defining the storage interface."""
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int, score_threshold: float
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
self, query: str, limit: int, score_threshold: float,
|
||||
) -> dict[str, Any] | list[Any]:
|
||||
return {}
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
"""An updated SQLite storage class for kickoff task outputs storage."""
|
||||
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
self, db_path: str | None = None,
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
inputs: Dict[str, Any] = {},
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Add a new task output record to the database.
|
||||
|
||||
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If saving the task output fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def update(
|
||||
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
else value,
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
|
||||
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
logger.warning(f"No row found with task_index {task_index}. No update performed.")
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def load(self) -> List[Dict[str, Any]]:
|
||||
def load(self) -> list[dict[str, Any]]:
|
||||
"""Load all task output records from the database.
|
||||
|
||||
Returns:
|
||||
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def delete_all(self) -> None:
|
||||
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
"""An updated SQLite storage class for LTM data storage."""
|
||||
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
self, db_path: str | None = None,
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initializes the SQLite database and creates LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
|
||||
datetime TEXT,
|
||||
score REAL
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
|
||||
def save(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: Dict[str, Any],
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: Union[int, float],
|
||||
score: float,
|
||||
) -> None:
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
|
||||
def load(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
self, task_description: str, latest_n: int,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient
|
||||
|
||||
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage
|
||||
|
||||
|
||||
class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
"""Extends Storage to handle embedding and searching across entities using Mem0."""
|
||||
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
def __init__(self, type, crew=None, config=None) -> None:
|
||||
super().__init__()
|
||||
supported_types = ["user", "short_term", "long_term", "entities", "external"]
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. Must be one of: "
|
||||
+ ", ".join(supported_types)
|
||||
+ ", ".join(supported_types),
|
||||
)
|
||||
|
||||
self.memory_type = type
|
||||
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
if type == "user" and not user_id:
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
msg = "User ID is required for user memory type"
|
||||
raise ValueError(msg)
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
config = self._get_config()
|
||||
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
|
||||
if mem0_api_key:
|
||||
if mem0_org_id and mem0_project_id:
|
||||
self.memory = MemoryClient(
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
|
||||
)
|
||||
else:
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
elif mem0_local_config and len(mem0_local_config):
|
||||
self.memory = Memory.from_config(mem0_local_config)
|
||||
else:
|
||||
if mem0_local_config and len(mem0_local_config):
|
||||
self.memory = Memory.from_config(mem0_local_config)
|
||||
else:
|
||||
self.memory = Memory()
|
||||
self.memory = Memory()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
user_id = self._get_user_id()
|
||||
agent_name = self._get_agent_name()
|
||||
params = None
|
||||
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
params = {"query": query, "limit": limit, "output_format": "v1.1"}
|
||||
if user_id := self._get_user_id():
|
||||
params["user_id"] = user_id
|
||||
@@ -120,7 +116,7 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["output_format"]
|
||||
|
||||
|
||||
results = self.memory.search(**params)
|
||||
return [r for r in results["results"] if r["score"] >= score_threshold]
|
||||
|
||||
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
|
||||
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return agents
|
||||
return "_".join(agents)
|
||||
|
||||
def _get_config(self) -> Dict[str, Any]:
|
||||
def _get_config(self) -> dict[str, Any]:
|
||||
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
@@ -32,16 +32,15 @@ def suppress_logging(
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
"""Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
|
||||
def _set_embedder_config(self) -> None:
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
def _initialize_app(self) -> None:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
name=self.type, embedding_function=self.embedder_config,
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
name=self.type, embedding_function=self.embedder_config,
|
||||
)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
"""Ensures file name does not exceed max allowed by OS."""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
logging.exception(f"Error during {self.type} save: {e!s}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
logging.exception(f"Error during {self.type} search: {e!s}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
msg = f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
msg,
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
|
||||
)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.memory import Memory
|
||||
|
||||
|
||||
class UserMemory(Memory):
|
||||
"""
|
||||
UserMemory class for handling user memory storage and retrieval.
|
||||
"""UserMemory class for handling user memory storage and retrieval.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None):
|
||||
def __init__(self, crew=None) -> None:
|
||||
warnings.warn(
|
||||
"UserMemory is deprecated and will be removed in a future version. "
|
||||
"Please use ExternalMemory instead.",
|
||||
@@ -22,8 +21,9 @@ class UserMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="user", crew=crew)
|
||||
super().__init__(storage)
|
||||
@@ -31,8 +31,8 @@ class UserMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||
data = f"Remember the details about the user: {value}"
|
||||
@@ -44,15 +44,15 @@ class UserMemory(Memory):
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
results = self.storage.search(
|
||||
return self.storage.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the user memory: {e}")
|
||||
msg = f"An error occurred while resetting the user memory: {e}"
|
||||
raise Exception(msg)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class UserMemoryItem:
|
||||
def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
self.data = data
|
||||
self.user = user
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
|
||||
Reference in New Issue
Block a user