Apply automatic linting fixes to src directory

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-12 13:30:50 +00:00
parent 807dfe0558
commit ad1ea46bbb
160 changed files with 3218 additions and 3197 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from crewai.memory import (
EntityMemory,
@@ -12,13 +12,13 @@ from crewai.memory import (
class ContextualMemory:
def __init__(
self,
memory_config: Optional[Dict[str, Any]],
memory_config: dict[str, Any] | None,
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: UserMemory,
exm: ExternalMemory,
):
) -> None:
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
else:
@@ -30,8 +30,7 @@ class ContextualMemory:
self.exm = exm
def build_context_for_task(self, task, context) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
"""Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""
query = f"{task.description} {context}".strip()
@@ -49,11 +48,9 @@ class ContextualMemory:
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
"""Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
"""
if self.stm is None:
return ""
@@ -62,16 +59,14 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in stm_results
]
],
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
def _fetch_ltm_context(self, task) -> str | None:
"""Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
if self.ltm is None:
return ""
@@ -90,8 +85,7 @@ class ContextualMemory:
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
"""Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
"""
if self.em is None:
@@ -102,19 +96,20 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in em_results
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
return f"Entities:\n{formatted_results}" if em_results else ""
def _fetch_user_context(self, query: str) -> str:
"""
Fetches and formats relevant user information from User Memory.
"""Fetches and formats relevant user information from User Memory.
Args:
query (str): The search query to find relevant user memories.
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
"""
if self.um is None:
return ""
@@ -128,12 +123,14 @@ class ContextualMemory:
return f"User memories/preferences:\n{formatted_memories}"
def _fetch_external_context(self, query: str) -> str:
"""
Fetches and formats relevant information from External Memory.
"""Fetches and formats relevant information from External Memory.
Args:
query (str): The search query to find relevant information.
Returns:
str: Formatted information as bullet points, or an empty string if none found.
"""
if self.exm is None:
return ""

View File

@@ -1,4 +1,3 @@
from typing import Optional
from pydantic import PrivateAttr
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
class EntityMemory(Memory):
"""
EntityMemory class for managing structured information about entities
"""EntityMemory class for managing structured information about entities
and their relationships using SQLite storage.
Inherits from the Memory class.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="entities", crew=crew)
else:
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
msg = f"An error occurred while resetting the entity memory: {e}"
raise Exception(msg)

View File

@@ -5,7 +5,7 @@ class EntityMemoryItem:
type: str,
description: str,
relationships: str,
):
) -> None:
self.name = name
self.type = type
self.description = description

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
msg = "embedder_config is required"
raise ValueError(msg)
if "provider" not in embedder_config:
raise ValueError("embedder_config must include a 'provider' key")
msg = "embedder_config must include a 'provider' key"
raise ValueError(msg)
provider = embedder_config["provider"]
supported_storages = ExternalMemory.external_supported_storages()
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
msg = f"Provider {provider} not supported"
raise ValueError(msg)
return supported_storages[provider](crew, embedder_config.get("config", {}))
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
"""Saves a value into the external storage."""
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)

View File

@@ -1,13 +1,13 @@
from typing import Any, Dict, Optional
from typing import Any
class ExternalMemoryItem:
def __init__(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
):
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
self.value = value
self.metadata = metadata
self.agent = agent

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
class LongTermMemory(Memory):
"""
LongTermMemory class for managing cross runs data related to overall crew's
"""LongTermMemory class for managing cross runs data related to overall crew's
execution and performance.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(self, storage=None, path=None) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
class LongTermMemoryItem:
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
quality: float | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
self.task = task
self.agent = agent
self.quality = quality

View File

@@ -1,26 +1,24 @@
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, now supporting agent tags and generic metadata."""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
def __init__(self, storage: Any, **data: Any):
def __init__(self, storage: Any, **data: Any) -> None:
super().__init__(storage=storage, **data)
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
metadata = metadata or {}
if agent:
@@ -33,9 +31,9 @@ class Memory(BaseModel):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold,
)
def set_crew(self, crew: Any) -> "Memory":

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from pydantic import PrivateAttr
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage
class ShortTermMemory(Memory):
"""
ShortTermMemory class for managing transient data related to immediate tasks
"""ShortTermMemory class for managing transient data related to immediate tasks
and interactions.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="short_term", crew=crew)
else:
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self._memory_provider == "mem0":
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
score_threshold: float = 0.35,
):
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold,
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
msg = f"An error occurred while resetting the short-term memory: {e}"
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
msg,
)

View File

@@ -1,13 +1,13 @@
from typing import Any, Dict, Optional
from typing import Any
class ShortTermMemoryItem:
def __init__(
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
agent: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
self.data = data
self.agent = agent
self.metadata = metadata if metadata is not None else {}

View File

@@ -1,11 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseRAGStorage(ABC):
"""
Base class for RAG-based Storage implementations.
"""
"""Base class for RAG-based Storage implementations."""
app: Any | None = None
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
) -> None:
self.type = type
self.allow_reset = allow_reset
self.embedder_config = embedder_config
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
def _initialize_agents(self) -> str:
if self.crew:
return "_".join(
[self._sanitize_role(agent.role) for agent in self.crew.agents]
[self._sanitize_role(agent.role) for agent in self.crew.agents],
)
return ""
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
self, text: str, metadata: dict[str, Any] | None = None,
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
def setup_config(self, config: dict[str, Any]) -> None:
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass
def initialize_client(self) -> None:
"""Initialize the client of the storage. This should setup the app and the db collection."""

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List
from typing import Any
class Storage:
"""Abstract base class defining the storage interface"""
"""Abstract base class defining the storage interface."""
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass
def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
self, query: str, limit: int, score_threshold: float,
) -> dict[str, Any] | list[Any]:
return {}
def reset(self) -> None:

View File

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.task import Task
from crewai.utilities import Printer
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)
class KickoffTaskOutputsSQLiteStorage:
"""
An updated SQLite storage class for kickoff task outputs storage.
"""
"""An updated SQLite storage class for kickoff task outputs storage."""
def __init__(
self, db_path: Optional[str] = None
self, db_path: str | None = None,
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
""",
)
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] = {},
inputs: dict[str, Any] | None = None,
) -> None:
"""Add a new task output record to the database.
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If saving the task output fails due to SQLite errors.
"""
if inputs is None:
inputs = {}
try:
with sqlite3.connect(self.db_path) as conn:
conn.execute("BEGIN TRANSACTION")
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def update(
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
else value,
)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
logger.warning(f"No row found with task_index {task_index}. No update performed.")
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def load(self) -> List[Dict[str, Any]]:
def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database.
Returns:
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def delete_all(self) -> None:
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)

View File

@@ -1,19 +1,17 @@
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""An updated SQLite storage class for LTM data storage."""
def __init__(
self, db_path: Optional[str] = None
self, db_path: str | None = None,
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initializes the SQLite database and creates LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
datetime TEXT,
score REAL
)
"""
""",
)
conn.commit()
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
def save(
self,
task_description: str,
metadata: Dict[str, Any],
metadata: dict[str, Any],
datetime: str,
score: Union[int, float],
score: float,
) -> None:
"""Saves data to the LTM table with error handling."""
try:
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
)
def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
self, task_description: str, latest_n: int,
) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any
from mem0 import Memory, MemoryClient
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage
class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
"""Extends Storage to handle embedding and searching across entities using Mem0."""
def __init__(self, type, crew=None, config=None):
def __init__(self, type, crew=None, config=None) -> None:
super().__init__()
supported_types = ["user", "short_term", "long_term", "entities", "external"]
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: "
+ ", ".join(supported_types)
+ ", ".join(supported_types),
)
self.memory_type = type
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
if type == "user" and not user_id:
raise ValueError("User ID is required for user memory type")
msg = "User ID is required for user memory type"
raise ValueError(msg)
# API key in memory config overrides the environment variable
config = self._get_config()
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
)
else:
self.memory = MemoryClient(api_key=mem0_api_key)
elif mem0_local_config and len(mem0_local_config):
self.memory = Memory.from_config(mem0_local_config)
else:
if mem0_local_config and len(mem0_local_config):
self.memory = Memory.from_config(mem0_local_config)
else:
self.memory = Memory()
self.memory = Memory()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
"""Sanitizes agent roles to ensure valid directory names."""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
user_id = self._get_user_id()
agent_name = self._get_agent_name()
params = None
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
params = {"query": query, "limit": limit, "output_format": "v1.1"}
if user_id := self._get_user_id():
params["user_id"] = user_id
@@ -120,7 +116,7 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
if isinstance(self.memory, Memory):
del params["metadata"], params["output_format"]
results = self.memory.search(**params)
return [r for r in results["results"] if r["score"] >= score_threshold]
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return agents
return "_".join(agents)
def _get_config(self) -> Dict[str, Any]:
def _get_config(self) -> dict[str, Any]:
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
def reset(self):
def reset(self) -> None:
if self.memory:
self.memory.reset()

View File

@@ -4,7 +4,7 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from typing import Any
from chromadb.api import ClientAPI
@@ -32,16 +32,15 @@ def suppress_logging(
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
"""Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
self.path = path
self._initialize_app()
def _set_embedder_config(self):
def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
def _initialize_app(self) -> None:
import chromadb
from chromadb.config import Settings
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
name=self.type, embedding_function=self.embedder_config,
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
name=self.type, embedding_function=self.embedder_config,
)
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
"""Sanitizes agent roles to ensure valid directory names."""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
"""Ensures file name does not exceed max allowed by OS."""
base_path = f"{db_storage_path()}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
)
file_name = file_name[:MAX_FILE_NAME_LENGTH]
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
try:
self._generate_embedding(value, metadata)
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
logging.exception(f"Error during {self.type} save: {e!s}")
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
if not hasattr(self, "app"):
self._initialize_app()
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
return results
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
logging.exception(f"Error during {self.type} search: {e!s}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
# Ignore this specific error
pass
else:
msg = f"An error occurred while resetting the {self.type} memory: {e}"
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
msg,
)
def _create_default_embedding_function(self):
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
)

View File

@@ -1,18 +1,17 @@
import warnings
from typing import Any, Dict, Optional
from typing import Any
from crewai.memory.memory import Memory
class UserMemory(Memory):
"""
UserMemory class for handling user memory storage and retrieval.
"""UserMemory class for handling user memory storage and retrieval.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
MemoryItem instances.
"""
def __init__(self, crew=None):
def __init__(self, crew=None) -> None:
warnings.warn(
"UserMemory is deprecated and will be removed in a future version. "
"Please use ExternalMemory instead.",
@@ -22,8 +21,9 @@ class UserMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
@@ -31,8 +31,8 @@ class UserMemory(Memory):
def save(
self,
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
@@ -44,15 +44,15 @@ class UserMemory(Memory):
limit: int = 3,
score_threshold: float = 0.35,
):
results = self.storage.search(
return self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
)
return results
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the user memory: {e}")
msg = f"An error occurred while resetting the user memory: {e}"
raise Exception(msg)

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, Optional
from typing import Any
class UserMemoryItem:
def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
self.data = data
self.user = user
self.metadata = metadata if metadata is not None else {}